diff --git a/package.json b/package.json index 77eefb4..7bdbbad 100644 --- a/package.json +++ b/package.json @@ -23,8 +23,10 @@ "discord.js": "^12.1.1", "dotenv": "^8.1.0", "eslint-plugin-jest": "^24.3.5", - "pg": "^7.12.1", + "pg": "^8.0.3", + "pg-mem": "^1.9.6", "postgrator": "^4.0.1", + "sqlutils": "^1.2.1", "typescript": "^4.1.5" }, "devDependencies": { diff --git a/src/app.ts b/src/app.ts index d2fc560..fc278e9 100644 --- a/src/app.ts +++ b/src/app.ts @@ -9,7 +9,7 @@ import Discord = require('discord.js'); import fs = require('fs'); -import {prefix, ID, toID, pgPool} from './common'; +import {prefix, ID, toID, database} from './common'; import {BaseCommand, BaseMonitor, DiscordChannel} from './command_base'; import {updateDatabase} from './database_version_control'; import * as child_process from 'child_process'; @@ -54,14 +54,12 @@ const userlist = new Set(); export async function verifyData(data: Discord.Message | IDatabaseInsert) { if (lockdown) return; - let worker = null; // Server if (data.guild && !servers.has(data.guild.id)) { - if (!worker) worker = await pgPool.connect(); - const res = await worker.query('SELECT * FROM servers WHERE serverid = $1', [data.guild.id]); - if (!res.rows.length) { - await worker.query( + const res = await database.queryWithResults('SELECT * FROM servers WHERE serverid = $1', [data.guild.id]); + if (!res.length) { + await database.query( 'INSERT INTO servers (serverid, servername, logchannel, sticky) VALUES ($1, $2, $3, $4)', [data.guild.id, data.guild.name, null, []] ); @@ -72,10 +70,9 @@ export async function verifyData(data: Discord.Message | IDatabaseInsert) { // Channel if (data.guild && data.channel && ['text', 'news'].includes(data.channel.type) && !channels.has(data.channel.id)) { const channel = (data.channel as Discord.TextChannel | Discord.NewsChannel); - if (!worker) worker = await pgPool.connect(); - const res = await worker.query('SELECT * FROM channels WHERE channelid = $1', [channel.id]); - if (!res.rows.length) { - await worker.query( + const res = await database.queryWithResults('SELECT * FROM channels WHERE channelid = $1', [channel.id]); + if (!res.length) { + await database.query( 'INSERT INTO channels (channelid, channelname, serverid) VALUES ($1, $2, $3)', [channel.id, channel.name, data.guild.id] ); @@ -85,10 +82,9 @@ export async function verifyData(data: Discord.Message | IDatabaseInsert) { // User if (data.author && !users.has(data.author.id)) { - if (!worker) worker = await pgPool.connect(); - const res = await worker.query('SELECT * FROM users WHERE userid = $1', [data.author.id]); - if (!res.rows.length) { - await worker.query( + const res = await database.queryWithResults('SELECT * FROM users WHERE userid = $1', [data.author.id]); + if (!res.length) { + await database.query( 'INSERT INTO users (userid, name, discriminator) VALUES ($1, $2, $3)', [data.author.id, data.author.username, data.author.discriminator] ); @@ -102,13 +98,12 @@ export async function verifyData(data: Discord.Message | IDatabaseInsert) { await data.guild.members.fetch(); const userInServer = data.guild.members.cache.has(data.author.id); if (userInServer) { - if (!worker) worker = await pgPool.connect(); - const res = await worker.query( + const res = await database.queryWithResults( 'SELECT * FROM userlist WHERE serverid = $1 AND userid = $2', [data.guild.id, data.author.id] ); - if (!res.rows.length) { - await worker.query( + if (!res.length) { + await database.query( 'INSERT INTO userlist (serverid, userid, boosting, sticky) VALUES ($1, $2, $3, $4)', [data.guild.id, data.author.id, null, []] ); @@ -116,8 +111,6 @@ export async function verifyData(data: Discord.Message | IDatabaseInsert) { userlist.add(data.guild.id + ',' + data.author.id); } } - - if (worker) worker.release(); } export const client = new Discord.Client(); @@ -199,8 +192,6 @@ client.on('message', (m) => void (async msg => { } catch (e) { await onError(e, 'A chat monitor crashed: '); } - // Release any workers regardless of the result - monitor.releaseWorker(true); } return; } @@ -225,8 +216,6 @@ client.on('message', (m) => void (async msg => { '\u274C - An error occured while trying to run your command. The error has been logged, and we will fix it soon.' ); } - // Release any workers regardless of the result - cmd.releaseWorker(true); })(m)); // Setup crash handlers diff --git a/src/command_base.ts b/src/command_base.ts index a177a4d..4c17195 100644 --- a/src/command_base.ts +++ b/src/command_base.ts @@ -5,8 +5,7 @@ * execution in general. */ import Discord = require('discord.js'); -import {prefix, toID, pgPool} from './common'; -import {PoolClient} from 'pg'; +import {prefix, toID, database} from './common'; import {client, verifyData} from './app'; export type DiscordChannel = Discord.TextChannel | Discord.NewsChannel; @@ -32,7 +31,6 @@ export abstract class BaseCommand { protected author: Discord.User; protected channel: DiscordChannel; protected guild: Discord.Guild | null; - protected worker: PoolClient | null; protected isMonitor: boolean; /** @@ -46,7 +44,6 @@ export abstract class BaseCommand { this.author = message.author; this.channel = (message.channel as DiscordChannel); this.guild = message.guild; - this.worker = null; this.isMonitor = false; } @@ -293,22 +290,10 @@ export abstract class BaseCommand { */ protected async sendLog(msg: string | Discord.MessageEmbed): Promise { if (!toID(msg) || !this.guild) return; - const res = await pgPool.query('SELECT logchannel FROM servers WHERE serverid = $1', [this.guild.id]); - const channel = this.getChannel(res.rows[0].logchannel, false, false); + const res = await database.queryWithResults('SELECT logchannel FROM servers WHERE serverid = $1', [this.guild.id]); + const channel = this.getChannel(res[0].logchannel, false, false); if (channel) await channel.send(msg); } - - /** - * Used by app.ts to release a PoolClient in the event - * a command using one crashes - */ - releaseWorker(warn = false): void { - if (this.worker) { - if (warn) console.warn(`Releasing PG worker for ${this.isMonitor ? 'monitor' : 'command'}: ${this.cmd}`); - this.worker.release(); - this.worker = null; - } - } } /** diff --git a/src/commands/activity.ts b/src/commands/activity.ts index b3f2aee..6f495a0 100644 --- a/src/commands/activity.ts +++ b/src/commands/activity.ts @@ -4,7 +4,7 @@ * channel activity monitors. */ import Discord = require('discord.js'); -import {prefix, toID, pgPool} from '../common'; +import {prefix, toID, database} from '../common'; import {BaseCommand, ReactionPageTurner, DiscordChannel, IAliasList} from '../command_base'; const ENGLISH_MONTH_NAMES = [ @@ -200,9 +200,7 @@ export class Leaderboard extends BaseCommand { } query += ' GROUP BY u.name, u.discriminator ORDER BY SUM(l.lines) desc;'; - const res = await pgPool.query(query, args); - - return res.rows; + return database.queryWithResults(query, args); } async execute() { @@ -269,9 +267,7 @@ export class ChannelLeaderboard extends BaseCommand { } query += ' GROUP BY ch.channelname ORDER BY SUM(cl.lines) desc;'; - const res = await pgPool.query(query, args); - - return res.rows; + return database.queryWithResults(query, args); } async execute() { @@ -333,9 +329,7 @@ export class Linecount extends BaseCommand { let query = `SELECT ${key ? key + ' AS time, ' : ''}SUM(l.lines) FROM lines l WHERE l.serverid = $1 AND l.userid = $2`; if (key) query += ` GROUP BY ${key} ORDER BY ${key} desc;`; const args = [this.guild.id, id]; - const res = await pgPool.query(query, args); - - return res.rows; + return database.queryWithResults(query, args); } async execute() { @@ -401,9 +395,7 @@ export class ChannelLinecount extends BaseCommand { query += ' WHERE ch.serverid = $1 AND ch.channelid = $2'; if (key) query += ` GROUP BY ${key} ORDER BY ${key} desc;`; const args = [this.guild.id, id]; - const res = await pgPool.query(query, args); - - return res.rows; + return database.queryWithResults(query, args); } async execute() { diff --git a/src/commands/boosts.ts b/src/commands/boosts.ts index 55ba049..79853b4 100644 --- a/src/commands/boosts.ts +++ b/src/commands/boosts.ts @@ -4,17 +4,15 @@ */ import Discord = require('discord.js'); import {client, verifyData} from '../app'; -import {prefix, pgPool} from '../common'; +import {prefix, database} from '../common'; import {BaseCommand, ReactionPageTurner, DiscordChannel} from '../command_base'; async function updateBoosters() { - const worker = await pgPool.connect(); - for (const [guildId, guild] of client.guilds.cache) { - const res = await worker.query('SELECT userid FROM userlist WHERE serverid = $1 AND boosting IS NOT NULL', [guildId]); - const boosting = res.rows.map(r => r.userid); - const logchannelResult = await pgPool.query('SELECT logchannel FROM servers WHERE serverid = $1', [guildId]); - const logChannel = client.channels.cache.get(logchannelResult.rows[0].logchannel) as DiscordChannel; + const res = await database.queryWithResults('SELECT userid FROM userlist WHERE serverid = $1 AND boosting IS NOT NULL', [guildId]); + const boosting = res.map(r => r.userid); + const logchannelResult = await database.queryWithResults('SELECT logchannel FROM servers WHERE serverid = $1', [guildId]); + const logChannel = client.channels.cache.get(logchannelResult[0].logchannel) as DiscordChannel; await guild.members.fetch(); for (const [id, gm] of guild.members.cache) { @@ -30,23 +28,23 @@ async function updateBoosters() { }); // Check if booster is in users table/userlist - if (!(await worker.query('SELECT userid FROM users WHERE userid = $1', [id])).rows.length) { - await worker.query( + if (!(await database.queryWithResults('SELECT userid FROM users WHERE userid = $1', [id])).length) { + await database.query( 'INSERT INTO users (userid, name, discriminator) VALUES ($1, $2, $3)', [gm.user.id, gm.user.username, gm.user.discriminator] ); } - const users = await worker.query('SELECT userid FROM userlist WHERE userid = $1 AND serverid = $2', [id, guildId]); - if (!users.rows.length) { + const users = await database.queryWithResults('SELECT userid FROM userlist WHERE userid = $1 AND serverid = $2', [id, guildId]); + if (!users.length) { // Insert with update - await worker.query( + await database.query( 'INSERT INTO userlist (serverid, userid, boosting) VALUES ($1, $2, $3)', [guildId, id, gm.premiumSince] ); } else { // Just update - await worker.query( + await database.query( 'UPDATE userlist SET boosting = $1 WHERE serverid = $2 AND userid = $3', [gm.premiumSince, guildId, id] ); @@ -54,21 +52,19 @@ async function updateBoosters() { await logChannel?.send(`<@${id}> has started boosting!`); } else { if (!boosting.includes(id)) continue; // Was not bosting before - await worker.query('UPDATE userlist SET boosting = NULL WHERE serverid = $1 AND userid = $2', [guildId, id]); + await database.query('UPDATE userlist SET boosting = NULL WHERE serverid = $1 AND userid = $2', [guildId, id]); await logChannel?.send(`<@${id}> is no longer boosting.`); boosting.splice(boosting.indexOf(id), 1); } } // Anyone left in boosting left the server and is no longer boosting - for (const desterter of boosting) { - await worker.query('UPDATE userlist SET boosting = NULL WHERE serverid = $1 AND userid = $2', [guildId, desterter]); - await logChannel?.send(`<@${desterter}> is no longer boosting because they left the server.`); + for (const deserter of boosting) { + await database.query('UPDATE userlist SET boosting = NULL WHERE serverid = $1 AND userid = $2', [guildId, deserter]); + await logChannel?.send(`<@${deserter}> is no longer boosting because they left the server.`); } } - worker.release(); - // Schedule next boost check const nextCheck = new Date(); nextCheck.setDate(nextCheck.getDate() + 1); @@ -133,17 +129,20 @@ export class Boosters extends BaseCommand { } async execute() { - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('MANAGE_ROLES'))) return this.errorReply('Access Denied.'); - const res = await pgPool.query('SELECT u.name, u.discriminator, ul.boosting ' + + const res = await database.queryWithResults( + 'SELECT u.name, u.discriminator, ul.boosting ' + 'FROM users u ' + 'INNER JOIN userlist ul ON u.userid = ul.userid ' + 'INNER JOIN servers s ON s.serverid = ul.serverid ' + 'WHERE s.serverid = $1 AND ul.boosting IS NOT NULL ' + - 'ORDER BY ul.boosting', [this.guild.id]); + 'ORDER BY ul.boosting', + [this.guild.id] + ); - const page = new BoostPage(this.channel, this.author, this.guild, res.rows); + const page = new BoostPage(this.channel, this.author, this.guild, res); await page.initialize(this.channel); } diff --git a/src/commands/dev.ts b/src/commands/dev.ts index a5ced61..1ea8b79 100644 --- a/src/commands/dev.ts +++ b/src/commands/dev.ts @@ -4,7 +4,7 @@ */ import Discord = require('discord.js'); import {shutdown} from '../app'; -import {prefix, pgPool} from '../common'; +import {prefix, database} from '../common'; import {BaseCommand, IAliasList} from '../command_base'; import * as child_process from 'child_process'; let updateLock = false; @@ -55,13 +55,13 @@ export class Query extends BaseCommand { async execute() { if (!(await this.can('EVAL'))) return this.errorReply('You do not have permission to do that.'); - pgPool.query(this.target, (err, res) => { - if (err) { - void this.sendCode(`An error occured: ${err.toString()}`); - } else { - void this.sendCode(this.formatResponse(res.rows)); - } - }); + + try { + const res = await database.queryWithResults(this.target, undefined); + await this.sendCode(this.formatResponse(res)); + } catch (err) { + await this.sendCode(`An error occured: ${err.toString()}`); + } } private formatResponse(rows: any[]): string { @@ -127,7 +127,7 @@ export class Shutdown extends BaseCommand { }, 10000); // empty the pool of database workers - await pgPool.end(); + await database.destroy(); // exit process.exit(); diff --git a/src/commands/moderation.ts b/src/commands/moderation.ts index d463e81..44e6c7f 100644 --- a/src/commands/moderation.ts +++ b/src/commands/moderation.ts @@ -4,7 +4,7 @@ * and administrators. */ import Discord = require('discord.js'); -import {prefix, toID, pgPool} from '../common'; +import {prefix, toID, database} from '../common'; import {BaseCommand} from '../command_base'; import {client} from '../app'; @@ -37,7 +37,7 @@ export class Whois extends BaseCommand { } async execute() { - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('KICK_MEMBERS'))) return this.errorReply('Access Denied.'); const user = this.getUser(this.target); @@ -98,13 +98,13 @@ export class WhoHas extends BaseCommand { } async execute() { - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('KICK_MEMBERS'))) return this.errorReply('Access Denied.'); if (!this.target.trim()) return this.reply(WhoHas.help()); const role = await this.getRole(this.target, true); - if (!role) return this.errorReply(`The role "${this.target}" was not found. Role names are Case Sensetive - Make sure your typing the role name exactly as it appears.`); + if (!role) return this.errorReply(`The role "${this.target}" was not found. Role names are case sensitive: make sure you're typing the role name exactly as it appears.`); const embed: Discord.MessageEmbedOptions = { color: 0x6194fd, @@ -160,9 +160,9 @@ abstract class StickyCommand extends BaseCommand { return true; } - async massStickyUpdate(role: Discord.Role, unsticky = false): Promise { + async massStickyUpdate(role: Discord.Role, unsticky = false): Promise<{statement: string; args: string[]} | null> { if (!this.guild || this.guild.id !== role.guild.id) throw new Error('Guild missmatch in sticky command'); - if (!role.members.size) return; // No members have this role, so no database update needed + if (!role.members.size) return null; // No members have this role, so no database update needed await this.guild.members.fetch(); let query = `UPDATE userlist SET sticky = ${unsticky ? 'array_remove' : 'array_append'}(sticky, $1) WHERE serverid = $2 AND userid IN (`; @@ -177,7 +177,7 @@ abstract class StickyCommand extends BaseCommand { query = query.slice(0, query.length - 2); query += ');'; - await pgPool.query(query, args); + return {statement: query, args}; } } @@ -188,7 +188,7 @@ export class Sticky extends StickyCommand { async execute() { if (!toID(this.target)) return this.reply(Sticky.help()); - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('MANAGE_ROLES'))) return this.errorReply('Access Denied'); const bot = this.guild.me ? this.guild.me.user : null; if (!bot) throw new Error('Bot user not found.'); @@ -216,34 +216,25 @@ export class Sticky extends StickyCommand { } // Validate @role is not already sticky (database query) - this.worker = await pgPool.connect(); - const res = await this.worker.query('SELECT sticky FROM servers WHERE serverid = $1', [this.guild.id]); - if (!res.rows.length) { + const res = await database.queryWithResults('SELECT sticky FROM servers WHERE serverid = $1', [this.guild.id]); + if (!res.length) { throw new Error(`Unable to find sticky roles in database for guild: ${this.guild.name} (${this.guild.id})`); } - const stickyRoles: string[] = res.rows[0].sticky; + const stickyRoles: string[] = res[0].sticky; if (stickyRoles.includes(role.id)) { - this.releaseWorker(); return this.errorReply('That role is already sticky!'); } // ---VALIDATION LINE--- // Make @role sticky (database update) stickyRoles.push(role.id); - try { - await this.worker.query('BEGIN'); - await this.worker.query('UPDATE servers SET sticky = $1 WHERE serverid = $2', [stickyRoles, this.guild.id]); - // Find all users with @role and perform database update so role is now sticky for them - await this.massStickyUpdate(role); - await this.worker.query('COMMIT'); - this.releaseWorker(); - } catch (e) { - await this.worker.query('ROLLBACK'); - this.releaseWorker(); - throw e; - } - // Return success message + + const queries = [{statement: 'UPDATE servers SET sticky = $1 WHERE serverid = $2', args: [stickyRoles, this.guild.id]}]; + const stickyUpdate = await this.massStickyUpdate(role); + if (stickyUpdate) queries.push(stickyUpdate); + await database.withinTransaction(queries); + await this.reply(`The role "${role.name}" is now sticky! Members who leave and rejoin the server with this role will have it reassigned automatically.`); } @@ -255,12 +246,12 @@ export class Sticky extends StickyCommand { static async init(): Promise { // This init is for all four sticky role commands - const res = await pgPool.query('SELECT serverid, sticky FROM servers'); - if (!res.rows.length) return; // No servers? + const res = await database.queryWithResults('SELECT serverid, sticky FROM servers'); + if (!res.length) return; // No servers? - for (const {sticky: stickyRoles, serverid} of res.rows) { + for (const {sticky: stickyRoles, serverid} of res) { // Get list of users and their sticky roles - const serverRes = await pgPool.query('SELECT userid, sticky FROM userlist WHERE serverid = $1', [serverid]); + const serverRes = await database.queryWithResults('SELECT userid, sticky FROM userlist WHERE serverid = $1', [serverid]); const server = client.guilds.cache.get(serverid); if (!server) { console.error('ERR NO SERVER FOUND'); @@ -268,7 +259,7 @@ export class Sticky extends StickyCommand { } await server.members.fetch(); - for (const {userid, sticky} of serverRes.rows) { + for (const {userid, sticky} of serverRes) { const member = server.members.cache.get(userid); if (!member) continue; // User left the server, but has not re-joined so we can't do anything but wait. // Check which of this member's roles are sticky @@ -277,7 +268,7 @@ export class Sticky extends StickyCommand { // Compare member's current sticky roles to the ones in the database. If they match, do nothing. const userStickyRoles: string[] = sticky; if (!roles.length && userStickyRoles.length) { - await pgPool.query( + await database.query( 'UPDATE userlist SET sticky = $1 WHERE serverid = $2 AND userid = $3', [roles, serverid, member.user.id] ); @@ -287,7 +278,7 @@ export class Sticky extends StickyCommand { if (roles.every(r => userStickyRoles.includes(r))) continue; // Update database with new roles - await pgPool.query( + await database.queryWithResults( 'UPDATE userlist SET sticky = $1 WHERE serverid = $2 AND userid = $3', [roles, serverid, member.user.id] ); @@ -303,7 +294,7 @@ export class Unsticky extends StickyCommand { async execute() { if (!toID(this.target)) return this.reply(Unsticky.help()); - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('MANAGE_ROLES'))) return this.errorReply('Access Denied'); // Validate @role exists @@ -319,33 +310,25 @@ export class Unsticky extends StickyCommand { } // Validate @role is sticky (database query) - this.worker = await pgPool.connect(); - const res = await this.worker.query('SELECT sticky FROM servers WHERE serverid = $1', [this.guild.id]); - if (!res.rows.length) { + const res = await database.queryWithResults('SELECT sticky FROM servers WHERE serverid = $1', [this.guild.id]); + if (!res.length) { throw new Error(`Unable to find sticky roles in database for guild: ${this.guild.name} (${this.guild.id})`); } - const stickyRoles: string[] = res.rows[0].sticky; + const stickyRoles: string[] = res[0].sticky; if (!stickyRoles.includes(role.id)) { - this.releaseWorker(); return this.errorReply('That role is not sticky!'); } // ---VALIDATION LINE--- // Make @role not sticky (database update) stickyRoles.splice(stickyRoles.indexOf(role.id), 1); - try { - await this.worker.query('BEGIN'); - await this.worker.query('UPDATE servers SET sticky = $1 WHERE serverid = $2', [stickyRoles, this.guild.id]); - // Find all users with @role and perform database update so role is no longer sticky for them - await this.massStickyUpdate(role, true); - await this.worker.query('COMMIT'); - this.releaseWorker(); - } catch (e) { - await this.worker.query('ROLLBACK'); - this.releaseWorker(); - throw e; - } + + const queries = [{statement: 'UPDATE servers SET sticky = $1 WHERE serverid = $2', args: [stickyRoles, this.guild.id]}]; + const stickyUpdate = await this.massStickyUpdate(role, true); + if (stickyUpdate) queries.push(stickyUpdate); + await database.withinTransaction(queries); + // Return success message await this.reply(`The role "${role.name}" is no longer sticky.`); } @@ -363,20 +346,16 @@ export class EnableLogs extends BaseCommand { } async execute() { - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('MANAGE_GUILD'))) return this.errorReply('Access Denied'); - this.worker = await pgPool.connect(); - const res = await this.worker.query('SELECT logchannel FROM servers WHERE serverid = $1', [this.guild.id]); - if (res.rows[0].logchannel) { - return this.errorReply(`This server is already set up to log to <#${res.rows[0].logchannel}>.`); + const res = await database.queryWithResults('SELECT logchannel FROM servers WHERE serverid = $1', [this.guild.id]); + if (res[0].logchannel) { + return this.errorReply(`This server is already set up to log to <#${res[0].logchannel}>.`); } - await this.worker.query('UPDATE servers SET logchannel = $1 WHERE serverid = $2', [this.channel.id, this.guild.id]); + await database.query('UPDATE servers SET logchannel = $1 WHERE serverid = $2', [this.channel.id, this.guild.id]); await this.reply('Server events will now be logged to this channel.'); - - this.worker.release(); - this.worker = null; } static help(): string { @@ -392,18 +371,14 @@ export class DisableLogs extends BaseCommand { } async execute() { - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('MANAGE_GUILD'))) return this.errorReply('Access Denied'); - this.worker = await pgPool.connect(); - const res = await this.worker.query('SELECT logchannel FROM servers WHERE serverid = $1', [this.guild.id]); - if (!res.rows[0].logchannel) return this.errorReply('This server is not setup to log messages to a log channel.'); + const res = await database.queryWithResults('SELECT logchannel FROM servers WHERE serverid = $1', [this.guild.id]); + if (!res[0].logchannel) return this.errorReply('This server is not setup to log messages to a log channel.'); - await this.worker.query('UPDATE servers SET logchannel = $1 WHERE serverid = $2', [null, this.guild.id]); + await database.query('UPDATE servers SET logchannel = $1 WHERE serverid = $2', [null, this.guild.id]); await this.reply('Server events will no longer be logged to this channel.'); - - this.worker.release(); - this.worker = null; } static help(): string { diff --git a/src/commands/rmt.ts b/src/commands/rmt.ts index 35711fa..a39d019 100644 --- a/src/commands/rmt.ts +++ b/src/commands/rmt.ts @@ -4,7 +4,7 @@ * Also see src/monitors/rmt.ts */ import Discord = require('discord.js'); -import {prefix, toID, pgPool} from '../common'; +import {prefix, toID, database} from '../common'; import {BaseCommand, DiscordChannel, IAliasList, ReactionPageTurner} from '../command_base'; export const aliases: IAliasList = { @@ -180,25 +180,21 @@ export class AddTeamRater extends RmtCommand { channel: channel, }); - this.worker = await pgPool.connect(); - // Ensure this user isnt already a rater for this format - const res = await this.worker.query( + const res = await database.queryWithResults( 'SELECT * FROM teamraters WHERE userid = $1 AND format = $2 AND channelid = $3', [user.id, format, channel.id] ); - if (res.rows.length) { + if (res.length) { // This user is already a rater for this format - this.releaseWorker(); return this.errorReply(`${user} is already a team rater for ${format} in ${channel}.`); } // Add user to team raters - await this.worker.query( + await database.query( 'INSERT INTO teamraters (userid, format, channelid) VALUES ($1, $2, $3)', [user.id, format, channel.id] ); - this.releaseWorker(); await this.reply(`${user.username} has been added as a team rater for ${format} in ${channel}`); } @@ -216,7 +212,7 @@ export class RemoveTeamRater extends RmtCommand { } async execute() { - if (!this.guild) return this.errorReply('This command is not mean\'t to be used in PMs.'); + if (!this.guild) return this.errorReply('This command is not meant to be used in PMs.'); if (!(await this.can('KICK_MEMBERS'))) return this.errorReply('Access Denied'); // Validate arguments @@ -240,17 +236,17 @@ export class RemoveTeamRater extends RmtCommand { } // Ensure this user is a rater for this format in this channel - const res = await pgPool.query( + const res = await database.queryWithResults( 'SELECT * FROM teamraters WHERE userid = $1 AND format = $2 AND channelid = $3', [user.id, format, channel.id] ); - if (!res.rows.length) { + if (!res.length) { // This user is not a rater for this format in this channel return this.errorReply(`${user.username} is not a team rater for ${format} in ${channel}.`); } // Remove user from team rater list - await pgPool.query( + await database.queryWithResults( 'DELETE FROM teamraters WHERE userid = $1 AND format = $2 AND channelid = $3', [user.id, format, channel.id] ); @@ -287,31 +283,40 @@ export class ListRaters extends RmtCommand { const channel = this.getChannel(rawChannel, true, true, allowServerName); if (!format) { - const res = await pgPool.query('SELECT DISTINCT u.name, u.discriminator, tr.format FROM teamraters tr ' + - 'INNER JOIN channels ch ON tr.channelid = ch.channelid ' + - 'INNER JOIN servers s ON ch.serverid = s.serverid ' + - 'INNER JOIN users u ON tr.userid = u.userid ' + - 'WHERE s.serverid = $1 ' + - 'ORDER BY tr.format;', [this.guild.id]); - - const page = new RaterList(this.channel, this.author, this.guild, res.rows); + const res = await database.queryWithResults( + 'SELECT DISTINCT u.name, u.discriminator, tr.format FROM teamraters tr ' + + 'INNER JOIN channels ch ON tr.channelid = ch.channelid ' + + 'INNER JOIN servers s ON ch.serverid = s.serverid ' + + 'INNER JOIN users u ON tr.userid = u.userid ' + + 'WHERE s.serverid = $1 ' + + 'ORDER BY tr.format', + [this.guild.id] + ); + + const page = new RaterList(this.channel, this.author, this.guild, res); await page.initialize(this.channel); } else if (channel) { - const res = await pgPool.query('SELECT u.name, u.discriminator, ch.channelname FROM teamraters tr ' + - 'INNER JOIN users u ON tr.userid = u.userid ' + - 'INNER JOIN channels ch ON tr.channelid = ch.channelid ' + - 'WHERE tr.format = $1 AND tr.channelid = $2 ' + - 'ORDER BY u.name, u.discriminator', [format, channel.id]); - - const page = new RaterList(this.channel, this.author, this.guild, res.rows, format); + const res = await database.queryWithResults( + 'SELECT u.name, u.discriminator, ch.channelname FROM teamraters tr ' + + 'INNER JOIN users u ON tr.userid = u.userid ' + + 'INNER JOIN channels ch ON tr.channelid = ch.channelid ' + + 'WHERE tr.format = $1 AND tr.channelid = $2 ' + + 'ORDER BY u.name, u.discriminator', + [format, channel.id] + ); + + const page = new RaterList(this.channel, this.author, this.guild, res, format); await page.initialize(this.channel); } else { - const res = await pgPool.query('SELECT DISTINCT u.name, u.discriminator FROM teamraters tr ' + - 'INNER JOIN users u ON tr.userid = u.userid ' + - 'WHERE tr.format = $1 ' + - 'ORDER BY u.name, u.discriminator', [format]); - - const page = new RaterList(this.channel, this.author, this.guild, res.rows, format); + const res = await database.queryWithResults( + 'SELECT DISTINCT u.name, u.discriminator FROM teamraters tr ' + + 'INNER JOIN users u ON tr.userid = u.userid ' + + 'WHERE tr.format = $1 ' + + 'ORDER BY u.name, u.discriminator', + [format] + ); + + const page = new RaterList(this.channel, this.author, this.guild, res, format); await page.initialize(this.channel); } } diff --git a/src/common.ts b/src/common.ts index 4cb547f..6882bbd 100644 --- a/src/common.ts +++ b/src/common.ts @@ -1,11 +1,13 @@ -import PG = require('pg'); +import * as PG from 'pg'; +import {Database, ExternalPostgresDatabase} from './lib/database'; export type ID = '' | string & {__isID: true}; // The prefix to all bot commands -export const prefix = process.env.PREFIX || '$'; +export const prefix = process.env.BOT_PREFIX || process.env.PREFIX || '$'; -export const pgPool = new PG.Pool(); +// Typed as Database because in unit tests, this will be changed to a MemoryPostgresDatabase +export const database: Database = new ExternalPostgresDatabase(new PG.Pool()); /** * toID - Turns anything into an ID (string with only lowercase alphanumeric characters) diff --git a/src/events.ts b/src/events.ts index 662cf16..4244b6b 100644 --- a/src/events.ts +++ b/src/events.ts @@ -4,13 +4,13 @@ * Exceptions are located in app.ts */ import Discord = require('discord.js'); -import {pgPool} from './common'; +import {database} from './common'; import {client} from './app'; async function getLogChannel(guild: Discord.Guild): Promise { - const res = await pgPool.query('SELECT logchannel FROM servers WHERE serverid = $1', [guild.id]); - if (!res.rows.length) return; - const channel = client.channels.cache.get(res.rows[0].logchannel); + const res = await database.queryWithResults('SELECT logchannel FROM servers WHERE serverid = $1', [guild.id]); + if (!res.length) return; + const channel = client.channels.cache.get(res[0].logchannel); if (!channel) return; return (channel as Discord.TextChannel); } @@ -178,48 +178,42 @@ client.on('guildMemberAdd', (m: Discord.GuildMember | Discord.PartialGuildMember const guild = member.guild; const bot = guild.me ? await guild.members.fetch(guild.me.user) : null; if (!bot) throw new Error('Bot user not found.'); - const worker = await pgPool.connect(); // try/catch so we don't leave a database worker out of the pool incase of an error - try { - const res = await worker.query( - 'SELECT sticky FROM userlist WHERE serverid = $1 AND userid = $2', - [guild.id, member.user.id] - ); - if (!res.rows.length) return; // User was not in database yet, which is OK here. Proably a first time join. - const sticky: string[] = res.rows[0].sticky; - if (!sticky.length) return; // User rejoined and had 0 sticky roles. - - // Re-assign sticky roles - if (!bot.hasPermission('MANAGE_ROLES')) { - // Bot can't assign roles due to lack of permissions + const res = await database.queryWithResults( + 'SELECT sticky FROM userlist WHERE serverid = $1 AND userid = $2', + [guild.id, member.user.id] + ); + if (!res.length) return; // User was not in database yet, which is OK here. Proably a first time join. + const sticky: string[] = res[0].sticky; + if (!sticky.length) return; // User rejoined and had 0 sticky roles. + + // Re-assign sticky roles + if (!bot.hasPermission('MANAGE_ROLES')) { + // Bot can't assign roles due to lack of permissions + const channel = await getLogChannel(guild); + const msg = '[WARN] Bot tried to assign sticky (persistant) roles to a user joining the server, but lacks the MANAGE_ROLES permission.'; + if (channel) await channel.send(msg); + return; + } + + await guild.roles.fetch(); + for (const roleID of sticky) { + const role = guild.roles.cache.get(roleID); + if (!role) { + // ??? Should never happen + throw new Error(`Unable to find sticky role with ID ${roleID} in server ${guild.name} (${guild.id})`); + } + + if (!(await canAssignRole(bot, role))) { + // Bot can no longer assign the role. const channel = await getLogChannel(guild); - const msg = '[WARN] Bot tried to assign sticky (persistant) roles to a user joining the server, but lacks the MANAGE_ROLES permission.'; + const msg = `[WARN] Bot tried to assign sticky (persistant) role "${role.name}" to a user joining the server, but lacks permissions to assign this specific role.`; if (channel) await channel.send(msg); - return; + continue; } - await guild.roles.fetch(); - for (const roleID of sticky) { - const role = guild.roles.cache.get(roleID); - if (!role) { - // ??? Should never happen - throw new Error(`Unable to find sticky role with ID ${roleID} in server ${guild.name} (${guild.id})`); - } - - if (!(await canAssignRole(bot, role))) { - // Bot can no longer assign the role. - const channel = await getLogChannel(guild); - const msg = `[WARN] Bot tried to assign sticky (persistant) role "${role.name}" to a user joining the server, but lacks permissions to assign this specific role.`; - if (channel) await channel.send(msg); - continue; - } - - await member.roles.add(role, 'Assigning sticky role to returning user'); - } - } catch (e) { - worker.release(); - throw e; + await member.roles.add(role, 'Assigning sticky role to returning user'); } })(m as Discord.GuildMember); }); @@ -227,26 +221,20 @@ client.on('guildMemberAdd', (m: Discord.GuildMember | Discord.PartialGuildMember client.on('roleDelete', (r: Discord.Role) => { void (async (role) => { const guild = role.guild; - const worker = await pgPool.connect(); - - try { - const res = await worker.query('SELECT sticky FROM servers WHERE serverid = $1', [guild.id]); - let sticky = res.rows[0].sticky; - if (!sticky.includes(role.id)) return; // Deleted role is not sticky - - // Remove references to sticky role - sticky = sticky.splice(sticky.indexOf(role.id), 1); - await worker.query('UPDATE servers SET sticky = $1 WHERE serverid = $2', [sticky, guild.id]); - - // Remove role from userlist - await worker.query( - 'UPDATE userlist SET sticky = array_remove(sticky, $1) WHERE serverid = $2 AND sticky @> ARRAY[$1]', - [role.id, guild.id] - ); - } catch (e) { - worker.release(); - throw e; - } + + const res = await database.queryWithResults('SELECT sticky FROM servers WHERE serverid = $1', [guild.id]); + let sticky = res[0].sticky; + if (!sticky.includes(role.id)) return; // Deleted role is not sticky + + // Remove references to sticky role + sticky = sticky.splice(sticky.indexOf(role.id), 1); + await database.query('UPDATE servers SET sticky = $1 WHERE serverid = $2', [sticky, guild.id]); + + // Remove role from userlist + await database.query( + 'UPDATE userlist SET sticky = array_remove(sticky, $1) WHERE serverid = $2 AND sticky @> ARRAY[$1]', + [role.id, guild.id] + ); })(r); }); @@ -259,21 +247,19 @@ client.on('guildMemberUpdate', (oldM: Discord.GuildMember | Discord.PartialGuild if (!addedRoles.length && !removedRoles.length) return; - const stickyRoles: string[] = ( - await pgPool.query('SELECT sticky FROM servers WHERE serverid = $1', [guild.id]) - ).rows[0].sticky; + const stickyRoles: string[] = (await database.queryWithResults('SELECT sticky FROM servers WHERE serverid = $1', [guild.id]))[0].sticky; addedRoles = addedRoles.filter(role => stickyRoles.includes(role.id)); removedRoles = removedRoles.filter(role => stickyRoles.includes(role.id)); if (!addedRoles.length && !removedRoles.length) return; let userRoles: string[] = ( - await pgPool.query('SELECT sticky FROM userlist WHERE serverid = $1 AND userid = $2', [guild.id, newMember.user.id]) - ).rows[0].sticky; + await database.queryWithResults('SELECT sticky FROM userlist WHERE serverid = $1 AND userid = $2', [guild.id, newMember.user.id]) + )[0].sticky; userRoles = userRoles.filter(roleID => !removedRoles.map(r => r.id).includes(roleID)); userRoles = userRoles.concat(addedRoles.map(r => r.id)); - await pgPool.query( + await database.queryWithResults( 'UPDATE userlist SET sticky = $1 WHERE serverid = $2 AND userid = $3', [userRoles, guild.id, newMember.user.id] ); diff --git a/src/lib/database.ts b/src/lib/database.ts new file mode 100644 index 0000000..0511b16 --- /dev/null +++ b/src/lib/database.ts @@ -0,0 +1,143 @@ +/** + * Database access code. + * + * Porygon-Z currently uses PostgreSQL as its database of choice. + * However, it can run using either an external Postgres server, or an in-memory Postgres database provided by pg-mem. + * The latter does not support 100% of Postgres' features; + * code depending on SQL queries which are incompatible with pg-mem can't be unit tested, + * but should theoretically run in production environments (with a real Postgres database). + * + * Porygon-Z does not currently plan to support other databases, but if it ever does, the code would go here. + * + * @author Annika + */ +import type {Pool} from 'pg'; +import type {IMemoryDb} from 'pg-mem'; + +import {escape as escapeSQL} from 'sqlutils/pg'; + +interface Query { + statement: string; + args?: any[]; +} + +export interface Database { + /** executes a query that doesn't return results */ + query(statement: string, args?: any[]): Promise; + + /** executes a query that returns results */ + queryWithResults(statement: string, args?: any[]): Promise; + + /** + * Executes several queries sequentially within a transaction. + * + * @returns true on success and false if an error occurs (in which case the transaction will be rolled back) + */ + withinTransaction(queries: Query[]): Promise; + + destroy(): Promise; +} + +export class ExternalPostgresDatabase implements Database { + private pool: Pool; + + constructor(pool: Pool) { + this.pool = pool; + } + + async query(statement: string, args?: any[]) { + await this.pool.query(statement, args); + } + + async queryWithResults(statement: string, args?: any[]) { + const result = await this.pool.query(statement, args); + return result.rows; + } + + async withinTransaction(queries: Query[]) { + const client = await this.pool.connect(); + try { + await client.query('BEGIN'); + for (const query of queries) { + await client.query(query.statement, query.args); + } + await client.query('COMMIT'); + + return true; + } catch (e) { + await client.query('ROLLBACK'); + return false; + } finally { + client.release(); + } + } + + async destroy() { + return this.pool.end(); + } +} + +export class MemoryPostgresDatabase implements Database { + private db: IMemoryDb; + + constructor(db: IMemoryDb) { + this.db = db; + } + + + query(statement: string, args?: any[]): Promise { + this.db.public.none(this.stringifyQuery(statement, args)); + return Promise.resolve(); + } + + queryWithResults(statement: string, args?: any[]): Promise { + return Promise.resolve(this.db.public.many(this.stringifyQuery(statement, args))); + } + + async withinTransaction(queries: Query[]) { + try { + await this.query('BEGIN'); + for (const query of queries) { + await this.query(query.statement, query.args); + } + await this.query('COMMIT'); + + return true; + } catch (e) { + await this.query('ROLLBACK'); + return false; + } + } + + destroy() { + return Promise.resolve(); + } + + /** + * Converts a Query (representing a parameterized statement) to a string. + * Ideally we wouldn't need to do this, but it shouldn't present too great of a security risk, + * since MemoryPostgresDatabase is only used for unit tests, not in production. + * + * This can be removed when https://github.com/oguimbal/pg-mem/issues/101 is fixed. + */ + stringifyQuery(statement: string, args?: any[]) { + if (!args?.length) return statement; + + // this is a fairly hacky solution - if only pg-mem had built-in parameterization... + // basically, we assume any instance of whitespace-"$"-digits is a parameter, and + // replace it with the given argument (sanitized, of course!). + return statement.replace(/(\s)\$(\d+)/g, (_, precedingWhitespace, indexString) => { + const index = parseInt(indexString); + if (isNaN(index) || index <= 0 || index > args.length) { + throw new Error(`Invalid index for parameterized statement: ${indexString}`); + } + + // SQL parameters start as $1, but array indices in JS start at arr[0], so we subtract 1 + return precedingWhitespace + escapeSQL(args[index - 1].toString()); + }); + } + + backup() { + return this.db.backup(); + } +} diff --git a/src/lib/sqlutils.d.ts b/src/lib/sqlutils.d.ts new file mode 100644 index 0000000..96fe2e1 --- /dev/null +++ b/src/lib/sqlutils.d.ts @@ -0,0 +1,3 @@ +declare module 'sqlutils/pg' { + function escape(statement: string): string; +} diff --git a/src/monitors/activity.ts b/src/monitors/activity.ts index 0646d49..5998109 100644 --- a/src/monitors/activity.ts +++ b/src/monitors/activity.ts @@ -4,7 +4,7 @@ * Also see src/commands/activity.ts */ import Discord = require('discord.js'); -import {pgPool} from '../common'; +import {database} from '../common'; import {BaseMonitor} from '../command_base'; // Number of days to keep lines for before pruning const LINE_PRUNE_CUTOFF = 60; @@ -16,36 +16,16 @@ async function prune() { const cutoff = new Date(); cutoff.setDate(cutoff.getDate() - LINE_PRUNE_CUTOFF); cutoff.setHours(0, 0, 0, 0); - const worker = await pgPool.connect(); - try { - await worker.query('BEGIN'); + await database.withinTransaction([ + {statement: 'DELETE FROM lines WHERE logdate < $1', args: [cutoff]}, + {statement: 'DELETE FROM channellines WHERE logdate < $1', args: [cutoff]}, + ]); - await worker.query('DELETE FROM lines WHERE logdate < $1', [cutoff]); - await worker.query('DELETE FROM channellines WHERE logdate < $1', [cutoff]); - - await worker.query('COMMIT'); - worker.release(); - - const nextPrune = new Date(); - nextPrune.setDate(nextPrune.getDate() + 1); - nextPrune.setHours(0, 0, 0, 0); - setTimeout(() => { - void prune(); - }, nextPrune.getTime() - Date.now()); - } catch (e) { - await worker.query('ROLLBACK'); - worker.release(); - - const nextPrune = new Date(); - nextPrune.setDate(nextPrune.getDate() + 1); - nextPrune.setHours(0, 0, 0, 0); - setTimeout(() => { - void prune(); - }, nextPrune.getTime() - Date.now()); - - throw e; - } + const nextPrune = new Date(); + nextPrune.setDate(nextPrune.getDate() + 1); + nextPrune.setHours(0, 0, 0, 0); + setTimeout(() => void prune(), nextPrune.getTime() - Date.now()); } // Prune any old logs on startup, also starts the timer for pruning @@ -78,7 +58,6 @@ export class ActivityMonitor extends BaseMonitor { // Should never happen, monitors do not run in PMs throw new Error('Activity monitor attempted to run outide of a guild.'); } - this.worker = await pgPool.connect(); const date = new Date(); // Log date await this.verifyData({ @@ -88,43 +67,40 @@ export class ActivityMonitor extends BaseMonitor { }); // Insert user line info - let res = await this.worker.query( + let res = await database.queryWithResults( 'SELECT * FROM lines WHERE userid = $1 AND logdate = $2 AND serverid = $3', [this.author.id, date, this.guild.id] ); - if (!res.rows.length) { + if (!res.length) { // Insert new row - await this.worker.query( + await database.queryWithResults( 'INSERT INTO lines (userid, logdate, serverid, lines) VALUES ($1, $2, $3, 1)', [this.author.id, date, this.guild.id] ); } else { // update row - await this.worker.query( + await database.query( 'UPDATE lines SET lines = lines + 1 WHERE userid = $1 AND logdate = $2 AND serverid = $3', [this.author.id, date, this.guild.id] ); } - res = await this.worker.query( + res = await database.queryWithResults( 'SELECT * FROM channellines WHERE channelid = $1 AND logdate = $2', [this.channel.id, date] ); - if (!res.rows.length) { + if (!res.length) { // Insert new row - await this.worker.query( + await database.queryWithResults( 'INSERT INTO channellines (channelid, logdate, lines) VALUES ($1, $2, 1)', [this.channel.id, date] ); } else { // Update row - await this.worker.query( + await database.queryWithResults( 'UPDATE channellines SET lines = lines + 1 WHERE channelid = $1 AND logdate = $2', [this.channel.id, date] ); } - - this.worker.release(); - this.worker = null; } } diff --git a/src/monitors/rmt.ts b/src/monitors/rmt.ts index a73828e..0e25ecc 100644 --- a/src/monitors/rmt.ts +++ b/src/monitors/rmt.ts @@ -5,7 +5,7 @@ * Also see src/commands/rmt.ts */ import Discord = require('discord.js'); -import {toID, pgPool} from '../common'; +import {toID, database} from '../common'; import {BaseMonitor} from '../command_base'; const cooldowns: {[channelid: string]: {[formatid: string]: number}} = {}; @@ -59,29 +59,29 @@ export class TeamRatingMonitor extends BaseMonitor { } async shouldExecute() { - let res = await pgPool.query('SELECT channelid FROM teamraters WHERE channelid = $1', [this.channel.id]); - if (!res.rows.length) return false; // This channel isn't setup for team rating. + let res = await database.queryWithResults('SELECT channelid FROM teamraters WHERE channelid = $1', [this.channel.id]); + if (!res.length) return false; // This channel isn't setup for team rating. if (!this.teamPasteRegexp.test(this.target)) return false; const format = this.formatRegexp.exec(this.target); if (!format || !format.length) return false; this.format = this.transformFormat(toID(format[0])); if (!this.format.startsWith('gen')) return false; - res = await pgPool.query( + res = await database.queryWithResults( 'SELECT userid FROM teamraters WHERE format = $1 AND channelid = $2', [this.format, this.channel.id] ); - if (!res.rows.length) { + if (!res.length) { return false; // No results } else { - res.rows = res.rows.filter(r => { + res = res.filter(r => { const user = this.getUser(r.userid); if (!user || user.presence.status === 'offline') return false; return true; }); - if (!res.rows.length) return false; - this.raters = res.rows.map(r => `<@${r.userid as string}>`); + if (!res.length) return false; + this.raters = res.map(r => `<@${r.userid as string}>`); } const cooldown = cooldowns[this.channel.id]?.[this.format]; if (cooldown && cooldown + (1000 * 60 * 60) >= Date.now()) return false;