diff --git a/lib/index.ts b/lib/index.ts index c0f5cb6..fb7ad64 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -61,13 +61,27 @@ export function createAdapter( }; } +function ensureFinaliser() { + let finaliser = () => {}; + const setFinaliser = (value) => (finaliser = value); + const wrapWithFinaliser = (fn) => (...rest) => { + try { + return fn(...rest); + } finally { + finaliser(); + } + }; + + return [wrapWithFinaliser, setFinaliser]; +} + export class RedisAdapter extends Adapter { public readonly uid; public readonly requestsTimeout: number; private readonly channel: string; private readonly requestChannel: string; - private readonly responseChannel: string; + private readonly responseChannelPrefix: string; private requests: Map = new Map(); /** @@ -95,33 +109,35 @@ export class RedisAdapter extends Adapter { this.channel = prefix + "#" + nsp.name + "#"; this.requestChannel = prefix + "-request#" + this.nsp.name + "#"; - this.responseChannel = prefix + "-response#" + this.nsp.name + "#"; - - const onError = (err) => { - if (err) { - this.emit("error", err); - } - }; + this.responseChannelPrefix = prefix + "-response#" + this.nsp.name + "#"; - this.subClient.psubscribe(this.channel + "*", onError); - this.subClient.on("pmessageBuffer", this.onmessage.bind(this)); + this.subClient.psubscribe(this.channel + "*", this.onError); + this.subClient.on("pmessageBuffer", this.onmessage); - this.subClient.subscribe( - [this.requestChannel, this.responseChannel], - onError - ); - this.subClient.on("messageBuffer", this.onrequest.bind(this)); + this.subClient.subscribe([this.requestChannel], this.onError); + this.subClient.on("messageBuffer", this.onrequest); - this.pubClient.on("error", onError); - this.subClient.on("error", onError); + this.pubClient.on("error", this.onError); + this.subClient.on("error", this.onError); } + /** + * Called on errors + * + * @private + */ + private onError = (err) => { + if (err) { + this.emit("error", err); + } + }; + /** * Called with a subscription message * * @private */ - private onmessage(pattern, channel, msg) { + private onmessage = (pattern, channel, msg) => { channel = channel.toString(); const channelMatches = channel.startsWith(this.channel); @@ -150,6 +166,37 @@ export class RedisAdapter extends Adapter { opts.except = new Set(opts.except); super.broadcast(packet, opts); + }; + + /** + * Called to start request + * + * @private + */ + private startrequest(requestId, request) { + if (requestId) { + this.subClient.subscribe([this.responseChannel(requestId)], this.onError); + } + + this.pubClient.publish(this.requestChannel, request); + } + + /** + * Called to end request and cleanup any response subscriptions + * + * @private + */ + private endrequest(requestId) { + this.subClient.unsubscribe([this.responseChannel(requestId)], this.onError); + } + + /** + * Called to get response channel id + * + * @private + */ + private responseChannel(requestId) { + return this.responseChannelPrefix + requestId; } /** @@ -157,10 +204,10 @@ export class RedisAdapter extends Adapter { * * @private */ - private async onrequest(channel, msg) { + private onrequest = async (channel, msg) => { channel = channel.toString(); - if (channel.startsWith(this.responseChannel)) { + if (channel.startsWith(this.responseChannelPrefix)) { return this.onresponse(channel, msg); } else if (!channel.startsWith(this.requestChannel)) { return debug("ignore different channel"); @@ -192,7 +239,10 @@ export class RedisAdapter extends Adapter { sockets: [...sockets], }); - this.pubClient.publish(this.responseChannel, response); + this.pubClient.publish( + this.responseChannel(request.requestId), + response + ); break; case RequestType.ALL_ROOMS: @@ -205,7 +255,10 @@ export class RedisAdapter extends Adapter { rooms: [...this.rooms.keys()], }); - this.pubClient.publish(this.responseChannel, response); + this.pubClient.publish( + this.responseChannel(request.requestId), + response + ); break; case RequestType.REMOTE_JOIN: @@ -228,7 +281,10 @@ export class RedisAdapter extends Adapter { requestId: request.requestId, }); - this.pubClient.publish(this.responseChannel, response); + this.pubClient.publish( + this.responseChannel(request.requestId), + response + ); break; case RequestType.REMOTE_LEAVE: @@ -251,7 +307,10 @@ export class RedisAdapter extends Adapter { requestId: request.requestId, }); - this.pubClient.publish(this.responseChannel, response); + this.pubClient.publish( + this.responseChannel(request.requestId), + response + ); break; case RequestType.REMOTE_DISCONNECT: @@ -274,7 +333,10 @@ export class RedisAdapter extends Adapter { requestId: request.requestId, }); - this.pubClient.publish(this.responseChannel, response); + this.pubClient.publish( + this.responseChannel(request.requestId), + response + ); break; case RequestType.REMOTE_FETCH: @@ -298,7 +360,10 @@ export class RedisAdapter extends Adapter { })), }); - this.pubClient.publish(this.responseChannel, response); + this.pubClient.publish( + this.responseChannel(request.requestId), + response + ); break; case RequestType.SERVER_SIDE_EMIT: @@ -320,7 +385,7 @@ export class RedisAdapter extends Adapter { called = true; debug("calling acknowledgement with %j", arg); this.pubClient.publish( - this.responseChannel, + this.responseChannel(request.requestId), JSON.stringify({ type: RequestType.SERVER_SIDE_EMIT, requestId: request.requestId, @@ -335,7 +400,7 @@ export class RedisAdapter extends Adapter { default: debug("ignoring unknown request type: %s", request.type); } - } + }; /** * Called on response from another node @@ -487,6 +552,10 @@ export class RedisAdapter extends Adapter { }); return new Promise((resolve, reject) => { + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + resolve = wrapWithFinaliser(resolve); + reject = wrapWithFinaliser(reject); + const timeout = setTimeout(() => { if (this.requests.has(requestId)) { reject( @@ -505,7 +574,10 @@ export class RedisAdapter extends Adapter { sockets: localSockets, }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); }); } @@ -530,6 +602,10 @@ export class RedisAdapter extends Adapter { }); return new Promise((resolve, reject) => { + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + resolve = wrapWithFinaliser(resolve); + reject = wrapWithFinaliser(reject); + const timeout = setTimeout(() => { if (this.requests.has(requestId)) { reject( @@ -548,7 +624,10 @@ export class RedisAdapter extends Adapter { rooms: localRooms, }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); }); } @@ -576,6 +655,10 @@ export class RedisAdapter extends Adapter { }); return new Promise((resolve, reject) => { + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + resolve = wrapWithFinaliser(resolve); + reject = wrapWithFinaliser(reject); + const timeout = setTimeout(() => { if (this.requests.has(requestId)) { reject( @@ -591,7 +674,10 @@ export class RedisAdapter extends Adapter { timeout, }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); }); } @@ -619,6 +705,10 @@ export class RedisAdapter extends Adapter { }); return new Promise((resolve, reject) => { + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + resolve = wrapWithFinaliser(resolve); + reject = wrapWithFinaliser(reject); + const timeout = setTimeout(() => { if (this.requests.has(requestId)) { reject( @@ -634,7 +724,10 @@ export class RedisAdapter extends Adapter { timeout, }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); }); } @@ -662,6 +755,10 @@ export class RedisAdapter extends Adapter { }); return new Promise((resolve, reject) => { + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + resolve = wrapWithFinaliser(resolve); + reject = wrapWithFinaliser(reject); + const timeout = setTimeout(() => { if (this.requests.has(requestId)) { reject( @@ -679,7 +776,10 @@ export class RedisAdapter extends Adapter { timeout, }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); }); } @@ -709,6 +809,10 @@ export class RedisAdapter extends Adapter { }); return new Promise((resolve, reject) => { + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + resolve = wrapWithFinaliser(resolve); + reject = wrapWithFinaliser(reject); + const timeout = setTimeout(() => { if (this.requests.has(requestId)) { reject( @@ -727,7 +831,10 @@ export class RedisAdapter extends Adapter { sockets: localSockets, }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); }); } @@ -745,7 +852,7 @@ export class RedisAdapter extends Adapter { rooms: [...rooms], }); - this.pubClient.publish(this.requestChannel, request); + this.startrequest(null, request); } public delSockets(opts: BroadcastOptions, rooms: Room[]) { @@ -762,7 +869,7 @@ export class RedisAdapter extends Adapter { rooms: [...rooms], }); - this.pubClient.publish(this.requestChannel, request); + this.startrequest(null, request); } public disconnectSockets(opts: BroadcastOptions, close: boolean) { @@ -779,7 +886,7 @@ export class RedisAdapter extends Adapter { close, }); - this.pubClient.publish(this.requestChannel, request); + this.startrequest(null, request); } public serverSideEmit(packet: any[]): void { @@ -798,11 +905,12 @@ export class RedisAdapter extends Adapter { data: packet, }); - this.pubClient.publish(this.requestChannel, request); + this.startrequest(null, request); } private async serverSideEmitWithAck(packet: any[]) { - const ack = packet.pop(); + const [wrapWithFinaliser, setFinaliser] = ensureFinaliser(); + const ack = wrapWithFinaliser(packet.pop()); const numSub = (await this.getNumSub()) - 1; // ignore self debug('waiting for %d responses to "serverSideEmit" request', numSub); @@ -840,7 +948,10 @@ export class RedisAdapter extends Adapter { responses: [], }); - this.pubClient.publish(this.requestChannel, request); + setFinaliser(() => { + this.endrequest(requestId); + }); + this.startrequest(requestId, request); } /**