fix: abort message handler using listeners

This commit is contained in:
Ricardo Arturo Cabral Mejía
2022-12-24 14:01:14 -05:00
parent 4c038a6c7e
commit 7469d97221
3 changed files with 39 additions and 23 deletions

View File

@@ -21,6 +21,8 @@ import { messageSchema } from '../schemas/message-schema'
const debug = createLogger('web-socket-adapter') const debug = createLogger('web-socket-adapter')
const debugHeartbeat = debug.extend('heartbeat') const debugHeartbeat = debug.extend('heartbeat')
const abortableMessageHandlers: WeakMap<WebSocket, IAbortable[]> = new WeakMap()
export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter { export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter {
public clientId: string public clientId: string
private clientAddress: string private clientAddress: string
@@ -33,23 +35,26 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
private readonly webSocketServer: IWebSocketServerAdapter, private readonly webSocketServer: IWebSocketServerAdapter,
private readonly createMessageHandler: Factory<IMessageHandler, [IncomingMessage, IWebSocketAdapter]>, private readonly createMessageHandler: Factory<IMessageHandler, [IncomingMessage, IWebSocketAdapter]>,
private readonly slidingWindowRateLimiter: Factory<IRateLimiter>, private readonly slidingWindowRateLimiter: Factory<IRateLimiter>,
private readonly settingsFactory: Factory<ISettings>, private readonly settings: Factory<ISettings>,
) { ) {
super() super()
this.alive = true this.alive = true
this.subscriptions = new Map() this.subscriptions = new Map()
this.clientId = Buffer.from(this.request.headers['sec-websocket-key'], 'base64').toString('hex') this.clientId = Buffer.from(this.request.headers['sec-websocket-key'], 'base64').toString('hex')
this.clientAddress = (this.request.headers['x-forwarded-for'] ?? this.request.socket.remoteAddress) as string const remoteIpHeader = this.settings().network?.remote_ip_header ?? 'x-forwarded-for'
this.clientAddress = (this.request.headers[remoteIpHeader] ?? this.request.socket.remoteAddress) as string
debug('client %s from address %s', this.clientId, this.clientAddress)
this.client this.client
.on('message', this.onClientMessage.bind(this)) .on('message', this.onClientMessage.bind(this))
.on('close', this.onClientClose.bind(this)) .on('close', this.onClientClose.bind(this))
.on('pong', this.onClientPong.bind(this)) .on('pong', this.onClientPong.bind(this))
.on('error', (error) => { .on('error', (error) => {
debug('error', error) if (error.name === 'RangeError' && error.message === 'Max payload size exceeded') {
debug('client %s from %s sent payload too large', this.clientId, this.clientAddress)
} else {
debug('error', error)
}
}) })
this this
@@ -60,7 +65,7 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
.on(WebSocketAdapterEvent.Broadcast, this.onBroadcast.bind(this)) .on(WebSocketAdapterEvent.Broadcast, this.onBroadcast.bind(this))
.on(WebSocketAdapterEvent.Message, this.sendMessage.bind(this)) .on(WebSocketAdapterEvent.Message, this.sendMessage.bind(this))
debug('client %s connected', this.clientId) debug('client %s connected from %s', this.clientId, this.clientAddress)
} }
public getClientId(): string { public getClientId(): string {
@@ -78,10 +83,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
} }
public onBroadcast(event: Event): void { public onBroadcast(event: Event): void {
debug('client %s broadcast event: %o', this.clientId, event)
this.webSocketServer.emit(WebSocketServerAdapterEvent.Broadcast, event) this.webSocketServer.emit(WebSocketServerAdapterEvent.Broadcast, event)
if (cluster.isWorker) { if (cluster.isWorker) {
debug('client %s broadcast event to primary: %o', this.clientId, event)
process.send({ process.send({
eventName: WebSocketServerAdapterEvent.Broadcast, eventName: WebSocketServerAdapterEvent.Broadcast,
event, event,
@@ -100,7 +103,6 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
} }
private sendMessage(message: OutgoingMessage): void { private sendMessage(message: OutgoingMessage): void {
debug('sending message to client %s: %o', this.clientId, message)
this.client.send(JSON.stringify(message)) this.client.send(JSON.stringify(message))
} }
@@ -127,7 +129,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
} }
private async onClientMessage(raw: Buffer) { private async onClientMessage(raw: Buffer) {
let abort: () => void let abortable = false
let messageHandler: IMessageHandler & IAbortable
try { try {
if (await this.isRateLimited(this.clientAddress)) { if (await this.isRateLimited(this.clientAddress)) {
this.sendMessage(createNoticeMessage('rate limited')) this.sendMessage(createNoticeMessage('rate limited'))
@@ -136,10 +139,13 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
const message = attemptValidation(messageSchema)(JSON.parse(raw.toString('utf8'))) const message = attemptValidation(messageSchema)(JSON.parse(raw.toString('utf8')))
const messageHandler = this.createMessageHandler([message, this]) as IMessageHandler & IAbortable messageHandler = this.createMessageHandler([message, this]) as IMessageHandler & IAbortable
if (typeof messageHandler?.abort === 'function') { abortable = typeof messageHandler?.abort === 'function'
abort = messageHandler.abort.bind(messageHandler)
this.client.prependOnceListener('close', abort) if (abortable) {
const handlers = abortableMessageHandlers.get(this.client) ?? []
handlers.push(messageHandler)
abortableMessageHandlers.set(this.client, handlers)
} }
await messageHandler?.handleMessage(message) await messageHandler?.handleMessage(message)
@@ -150,11 +156,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
debug('invalid message: %o', (error as any).annotate()) debug('invalid message: %o', (error as any).annotate())
this.sendMessage(createNoticeMessage(`Invalid message: ${error.message}`)) this.sendMessage(createNoticeMessage(`Invalid message: ${error.message}`))
} else { } else {
debug('unable to handle message: %o', error) console.error('unable to handle message', error)
} }
} finally { } finally {
if (abort) { if (abortable) {
this.client.removeListener('close', abort) const handlers = abortableMessageHandlers.get(this.client)
const index = handlers.indexOf(messageHandler)
if (index >= 0) {
handlers.splice(index, 1)
}
} }
} }
} }
@@ -163,10 +173,9 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
const { const {
rateLimits, rateLimits,
ipWhitelist = [], ipWhitelist = [],
} = this.settingsFactory().limits?.message ?? {} } = this.settings().limits?.message ?? {}
if (ipWhitelist.includes(client)) { if (ipWhitelist.includes(client)) {
debug('rate limit check %s: skipped', client)
return false return false
} }
@@ -195,8 +204,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
} }
private onClientClose() { private onClientClose() {
debug('client %s closing', this.clientId)
this.alive = false this.alive = false
this.subscriptions.clear()
const handlers = abortableMessageHandlers.get(this.client)
if (Array.isArray(handlers) && handlers.length) {
for (const handler of handlers) {
handler.abort()
}
}
this.removeAllListeners() this.removeAllListeners()
this.client.removeAllListeners() this.client.removeAllListeners()

View File

@@ -17,7 +17,7 @@ export const workerFactory = (): AppWorker => {
const server = http.createServer() const server = http.createServer()
const webSocketServer = new WebSocketServer({ const webSocketServer = new WebSocketServer({
server, server,
maxPayload: 131072, // 128 kB maxPayload: createSettings().network?.max_payload_size ?? 131072, // 128 kB
}) })
const adapter = new WebSocketServerAdapter( const adapter = new WebSocketServerAdapter(
server, server,

View File

@@ -70,8 +70,8 @@ export class SubscribeMessageHandler implements IMessageHandler, IAbortable {
) )
} catch (error) { } catch (error) {
if (error instanceof Error && error.name === 'AbortError') { if (error instanceof Error && error.name === 'AbortError') {
debug('aborted: %o', error) debug('subscription aborted: %o', error)
findEvents.end() findEvents.destroy()
} else { } else {
debug('error streaming events: %o', error) debug('error streaming events: %o', error)
} }