diff --git a/connection.ts b/connection.ts index a90db426..633f26b1 100644 --- a/connection.ts +++ b/connection.ts @@ -8,6 +8,7 @@ import { kUnstableCreateProtocol, kUnstablePipeline, kUnstableReadReply, + kUnstableWriteCommand, } from "./internal/symbols.ts"; import { delay } from "./vendor/https/deno.land/std/async/delay.ts"; @@ -35,6 +36,10 @@ export interface Connection { * @private */ [kUnstableReadReply](returnsUint8Arrays?: boolean): Promise; + /** + * @private + */ + [kUnstableWriteCommand](command: Command): Promise; /** * @private */ @@ -170,6 +175,10 @@ export class RedisConnection implements Connection { return this.#protocol.pipeline(commands); } + [kUnstableWriteCommand](command: Command): Promise { + return this.#protocol.writeCommand(command); + } + /** * Connect to Redis server */ diff --git a/internal/symbols.ts b/internal/symbols.ts index 8357b60e..3f49fe16 100644 --- a/internal/symbols.ts +++ b/internal/symbols.ts @@ -3,6 +3,11 @@ */ export const kUnstableReadReply = Symbol("deno-redis.readReply"); +/** + * @private + */ +export const kUnstableWriteCommand = Symbol("deno-redis.writeCommand"); + /** * @private */ diff --git a/protocol/deno_streams/command.ts b/protocol/deno_streams/command.ts index 5e62de7a..7f99707d 100644 --- a/protocol/deno_streams/command.ts +++ b/protocol/deno_streams/command.ts @@ -6,7 +6,7 @@ import type { RedisReply, RedisValue } from "../shared/types.ts"; import { encodeCommand } from "../shared/command.ts"; import type { Command } from "../shared/protocol.ts"; -async function writeRequest( +export async function writeCommand( writer: BufWriter, command: string, args: RedisValue[], @@ -22,7 +22,7 @@ export async function sendCommand( args: RedisValue[], returnUint8Arrays?: boolean, ): Promise { - await writeRequest(writer, command, args); + await writeCommand(writer, command, args); await writer.flush(); return readReply(reader, returnUint8Arrays); } @@ -33,7 +33,7 @@ export async function sendCommands( commands: Command[], ): Promise<(RedisReply | ErrorReplyError)[]> { for (const { command, args } of commands) { - await writeRequest(writer, command, args); + await writeCommand(writer, command, args); } await writer.flush(); const ret: (RedisReply | ErrorReplyError)[] = []; diff --git a/protocol/deno_streams/mod.ts b/protocol/deno_streams/mod.ts index e9968636..c43200bc 100644 --- a/protocol/deno_streams/mod.ts +++ b/protocol/deno_streams/mod.ts @@ -1,7 +1,7 @@ import { BufReader } from "../../vendor/https/deno.land/std/io/buf_reader.ts"; import { BufWriter } from "../../vendor/https/deno.land/std/io/buf_writer.ts"; import { readReply } from "./reply.ts"; -import { sendCommand, sendCommands } from "./command.ts"; +import { sendCommand, sendCommands, writeCommand } from "./command.ts"; import type { Command, Protocol as BaseProtocol } from "../shared/protocol.ts"; import { RedisReply, RedisValue } from "../shared/types.ts"; @@ -34,6 +34,11 @@ export class Protocol implements BaseProtocol { return readReply(this.#reader, returnsUint8Arrays); } + async writeCommand(command: Command): Promise { + await writeCommand(this.#writer, command.command, command.args); + await this.#writer.flush(); + } + pipeline(commands: Command[]): Promise> { return sendCommands(this.#writer, this.#reader, commands); } diff --git a/protocol/shared/protocol.ts b/protocol/shared/protocol.ts index 7c203501..810cdf34 100644 --- a/protocol/shared/protocol.ts +++ b/protocol/shared/protocol.ts @@ -14,6 +14,7 @@ export interface Protocol { returnsUint8Arrays?: boolean, ): Promise; readReply(returnsUint8Array?: boolean): Promise; + writeCommand(command: Command): Promise; pipeline( commands: Array, ): Promise>; diff --git a/protocol/web_streams/command.ts b/protocol/web_streams/command.ts index 97ccef59..fbbcf34b 100644 --- a/protocol/web_streams/command.ts +++ b/protocol/web_streams/command.ts @@ -4,7 +4,7 @@ import type { BufferedReadableStream } from "../../internal/buffered_readable_st import type { RedisReply, RedisValue } from "../shared/types.ts"; import { encodeCommand, encodeCommands } from "../shared/command.ts"; -async function writeRequest( +export async function writeCommand( writable: WritableStream, command: string, args: RedisValue[], @@ -25,7 +25,7 @@ export async function sendCommand( args: RedisValue[], returnUint8Arrays?: boolean, ): Promise { - await writeRequest(writable, command, args); + await writeCommand(writable, command, args); return readReply(readable, returnUint8Arrays); } diff --git a/protocol/web_streams/mod.ts b/protocol/web_streams/mod.ts index cbb0f7a8..b6f9734f 100644 --- a/protocol/web_streams/mod.ts +++ b/protocol/web_streams/mod.ts @@ -1,4 +1,4 @@ -import { sendCommand, sendCommands } from "./command.ts"; +import { sendCommand, sendCommands, writeCommand } from "./command.ts"; import { readReply } from "./reply.ts"; import type { Command, Protocol as BaseProtocol } from "../shared/protocol.ts"; import { RedisReply, RedisValue } from "../shared/types.ts"; @@ -30,6 +30,10 @@ export class Protocol implements BaseProtocol { return readReply(this.#readable, returnsUint8Arrays); } + writeCommand(command: Command): Promise { + return writeCommand(this.#writable, command.command, command.args); + } + pipeline(commands: Command[]): Promise> { return sendCommands(this.#writable, this.#readable, commands); } diff --git a/pubsub.ts b/pubsub.ts index 8aeb13a6..e6e7a742 100644 --- a/pubsub.ts +++ b/pubsub.ts @@ -2,7 +2,10 @@ import type { CommandExecutor } from "./executor.ts"; import { isRetriableError } from "./errors.ts"; import type { Binary } from "./protocol/shared/types.ts"; import { decoder } from "./internal/encoding.ts"; -import { kUnstableReadReply } from "./internal/symbols.ts"; +import { + kUnstableReadReply, + kUnstableWriteCommand, +} from "./internal/symbols.ts"; type DefaultMessageType = string; type ValidMessageType = string | string[]; @@ -43,28 +46,28 @@ class RedisSubscriptionImpl< constructor(private executor: CommandExecutor) {} async psubscribe(...patterns: string[]) { - await this.executor.exec("PSUBSCRIBE", ...patterns); + await this.#writeCommand("PSUBSCRIBE", patterns); for (const pat of patterns) { this.patterns[pat] = true; } } async punsubscribe(...patterns: string[]) { - await this.executor.exec("PUNSUBSCRIBE", ...patterns); + await this.#writeCommand("PUNSUBSCRIBE", patterns); for (const pat of patterns) { delete this.patterns[pat]; } } async subscribe(...channels: string[]) { - await this.executor.exec("SUBSCRIBE", ...channels); + await this.#writeCommand("SUBSCRIBE", channels); for (const chan of channels) { this.channels[chan] = true; } } async unsubscribe(...channels: string[]) { - await this.executor.exec("UNSUBSCRIBE", ...channels); + await this.#writeCommand("UNSUBSCRIBE", channels); for (const chan of channels) { delete this.channels[chan]; } @@ -155,6 +158,10 @@ class RedisSubscriptionImpl< close() { this.executor.connection.close(); } + + async #writeCommand(command: string, args: Array): Promise { + await this.executor.connection[kUnstableWriteCommand]({ command, args }); + } } export async function subscribe< diff --git a/redis.ts b/redis.ts index b4ed6519..8e4b2f85 100644 --- a/redis.ts +++ b/redis.ts @@ -58,6 +58,7 @@ import type { SimpleString, } from "./protocol/shared/types.ts"; import { createRedisPipeline } from "./pipeline.ts"; +import type { RedisSubscription } from "./pubsub.ts"; import { psubscribe, subscribe } from "./pubsub.ts"; import { convertMap, @@ -1180,16 +1181,30 @@ class RedisImpl implements Redis { return this.execIntegerReply("PUBLISH", channel, message); } - subscribe( + // deno-lint-ignore no-explicit-any + #subscription?: RedisSubscription; + async subscribe( ...channels: string[] ) { - return subscribe(this.executor, ...channels); + if (this.#subscription) { + await this.#subscription.subscribe(...channels); + return this.#subscription; + } + const subscription = await subscribe(this.executor, ...channels); + this.#subscription = subscription; + return subscription; } - psubscribe( + async psubscribe( ...patterns: string[] ) { - return psubscribe(this.executor, ...patterns); + if (this.#subscription) { + await this.#subscription.psubscribe(...patterns); + return this.#subscription; + } + const subscription = await psubscribe(this.executor, ...patterns); + this.#subscription = subscription; + return subscription; } pubsubChannels(pattern?: string) { diff --git a/tests/commands/pubsub.ts b/tests/commands/pubsub.ts index 2d3deac2..35d7c737 100644 --- a/tests/commands/pubsub.ts +++ b/tests/commands/pubsub.ts @@ -14,7 +14,7 @@ export function pubsubTests( ): void { const getOpts = () => ({ hostname: "127.0.0.1", port: getServer().port }); - it("subscribe() & unsubscribe()", async () => { + it("supports unsubscribing channels by `unsubscribe()`", async () => { const opts = getOpts(); const client = await connect(opts); const sub = await client.subscribe("subsc"); @@ -24,7 +24,7 @@ export function pubsubTests( client.close(); }); - it("receive()", async () => { + it("supports reading messages sequentially by `receive()`", async () => { const opts = getOpts(); const client = await connect(opts); const pub = await connect(opts); @@ -74,7 +74,7 @@ export function pubsubTests( }); }); - it("psubscribe()", async () => { + it("supports `psubscribe()`", async () => { const opts = getOpts(); const client = await connect(opts); const pub = await connect(opts); @@ -104,7 +104,7 @@ export function pubsubTests( client.close(); }); - it("retry", async () => { + it("supports automatic reconnection of subscribers", async () => { const opts = getOpts(); const port = nextPort(); let tempServer = await startRedis({ port }); @@ -126,11 +126,14 @@ export function pubsubTests( messages++; }, 900); + // Intentionally stops the server after the first message is delivered. setTimeout(() => stopRedis(tempServer), 1000); const { promise, resolve, reject } = Promise.withResolvers(); setTimeout(async () => { try { + // At this point, the server is assumed to be stopped. + // The subscriber and publisher should attempt to reconnect. assertEquals( subscriberClient.isConnected, false, @@ -141,14 +144,17 @@ export function pubsubTests( false, "The publisher client still thinks it is connected.", ); + assert(messages >= 1, "At least one message should be published."); assert(messages < 5, "Too many messages were published."); + // Reboot the server. tempServer = await startRedis({ port }); const tempClient = await connect({ ...opts, port }); await tempClient.ping(); tempClient.close(); + // Wait for the subscriber and publisher to reconnect... await delay(1000); assert( @@ -193,7 +199,7 @@ export function pubsubTests( }, }); - it("pubsubNumsub()", async () => { + it("supports `pubsubNumsub()`", async () => { const opts = getOpts(); const subClient1 = await connect(opts); await subClient1.subscribe("test1", "test2"); @@ -209,4 +215,46 @@ export function pubsubTests( subClient2.close(); pubClient.close(); }); + + it("supports calling `subscribe()` multiple times", async () => { + // https://github.com/denodrivers/redis/issues/390 + const opts = getOpts(); + const redis = await connect(opts); + const pub = await connect(opts); + const channel1 = "foo"; + const channel2 = "bar"; + + // First subscription + const sub1 = await redis.subscribe(channel1); + const it1 = sub1.receive(); + const promise1 = it1.next(); + try { + // Second subscription + const sub2 = await redis.subscribe(channel2); + try { + const message = "A"; + await pub.publish(channel1, message); + const result = await promise1; + assert(!result.done); + assertEquals(result.value, { channel: channel1, message }); + + const it2 = sub2.receive(); + const promise2 = it2.next(); + const message2 = "B"; + await pub.publish(channel2, message2); + const result2 = await promise2; + assert(!result2.done); + assertEquals(result2.value, { + channel: channel2, + message: message2, + }); + } finally { + sub2.close(); + } + } finally { + pub.close(); + sub1.close(); + redis.close(); + } + }); }