fix: abort message handler using listeners

This commit is contained in:
Ricardo Arturo Cabral Mejía 2022-12-24 14:01:14 -05:00
parent 4c038a6c7e
commit 7469d97221
3 changed files with 39 additions and 23 deletions

View File

@ -21,6 +21,8 @@ import { messageSchema } from '../schemas/message-schema'
const debug = createLogger('web-socket-adapter')
const debugHeartbeat = debug.extend('heartbeat')
const abortableMessageHandlers: WeakMap<WebSocket, IAbortable[]> = new WeakMap()
export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter {
public clientId: string
private clientAddress: string
@ -33,23 +35,26 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
private readonly webSocketServer: IWebSocketServerAdapter,
private readonly createMessageHandler: Factory<IMessageHandler, [IncomingMessage, IWebSocketAdapter]>,
private readonly slidingWindowRateLimiter: Factory<IRateLimiter>,
private readonly settingsFactory: Factory<ISettings>,
private readonly settings: Factory<ISettings>,
) {
super()
this.alive = true
this.subscriptions = new Map()
this.clientId = Buffer.from(this.request.headers['sec-websocket-key'], 'base64').toString('hex')
this.clientAddress = (this.request.headers['x-forwarded-for'] ?? this.request.socket.remoteAddress) as string
debug('client %s from address %s', this.clientId, this.clientAddress)
const remoteIpHeader = this.settings().network?.remote_ip_header ?? 'x-forwarded-for'
this.clientAddress = (this.request.headers[remoteIpHeader] ?? this.request.socket.remoteAddress) as string
this.client
.on('message', this.onClientMessage.bind(this))
.on('close', this.onClientClose.bind(this))
.on('pong', this.onClientPong.bind(this))
.on('error', (error) => {
debug('error', error)
if (error.name === 'RangeError' && error.message === 'Max payload size exceeded') {
debug('client %s from %s sent payload too large', this.clientId, this.clientAddress)
} else {
debug('error', error)
}
})
this
@ -60,7 +65,7 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
.on(WebSocketAdapterEvent.Broadcast, this.onBroadcast.bind(this))
.on(WebSocketAdapterEvent.Message, this.sendMessage.bind(this))
debug('client %s connected', this.clientId)
debug('client %s connected from %s', this.clientId, this.clientAddress)
}
public getClientId(): string {
@ -78,10 +83,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
}
public onBroadcast(event: Event): void {
debug('client %s broadcast event: %o', this.clientId, event)
this.webSocketServer.emit(WebSocketServerAdapterEvent.Broadcast, event)
if (cluster.isWorker) {
debug('client %s broadcast event to primary: %o', this.clientId, event)
process.send({
eventName: WebSocketServerAdapterEvent.Broadcast,
event,
@ -100,7 +103,6 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
}
private sendMessage(message: OutgoingMessage): void {
debug('sending message to client %s: %o', this.clientId, message)
this.client.send(JSON.stringify(message))
}
@ -127,7 +129,8 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
}
private async onClientMessage(raw: Buffer) {
let abort: () => void
let abortable = false
let messageHandler: IMessageHandler & IAbortable
try {
if (await this.isRateLimited(this.clientAddress)) {
this.sendMessage(createNoticeMessage('rate limited'))
@ -136,10 +139,13 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
const message = attemptValidation(messageSchema)(JSON.parse(raw.toString('utf8')))
const messageHandler = this.createMessageHandler([message, this]) as IMessageHandler & IAbortable
if (typeof messageHandler?.abort === 'function') {
abort = messageHandler.abort.bind(messageHandler)
this.client.prependOnceListener('close', abort)
messageHandler = this.createMessageHandler([message, this]) as IMessageHandler & IAbortable
abortable = typeof messageHandler?.abort === 'function'
if (abortable) {
const handlers = abortableMessageHandlers.get(this.client) ?? []
handlers.push(messageHandler)
abortableMessageHandlers.set(this.client, handlers)
}
await messageHandler?.handleMessage(message)
@ -150,11 +156,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
debug('invalid message: %o', (error as any).annotate())
this.sendMessage(createNoticeMessage(`Invalid message: ${error.message}`))
} else {
debug('unable to handle message: %o', error)
console.error('unable to handle message', error)
}
} finally {
if (abort) {
this.client.removeListener('close', abort)
if (abortable) {
const handlers = abortableMessageHandlers.get(this.client)
const index = handlers.indexOf(messageHandler)
if (index >= 0) {
handlers.splice(index, 1)
}
}
}
}
@ -163,10 +173,9 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
const {
rateLimits,
ipWhitelist = [],
} = this.settingsFactory().limits?.message ?? {}
} = this.settings().limits?.message ?? {}
if (ipWhitelist.includes(client)) {
debug('rate limit check %s: skipped', client)
return false
}
@ -195,8 +204,15 @@ export class WebSocketAdapter extends EventEmitter implements IWebSocketAdapter
}
private onClientClose() {
debug('client %s closing', this.clientId)
this.alive = false
this.subscriptions.clear()
const handlers = abortableMessageHandlers.get(this.client)
if (Array.isArray(handlers) && handlers.length) {
for (const handler of handlers) {
handler.abort()
}
}
this.removeAllListeners()
this.client.removeAllListeners()

View File

@ -17,7 +17,7 @@ export const workerFactory = (): AppWorker => {
const server = http.createServer()
const webSocketServer = new WebSocketServer({
server,
maxPayload: 131072, // 128 kB
maxPayload: createSettings().network?.max_payload_size ?? 131072, // 128 kB
})
const adapter = new WebSocketServerAdapter(
server,

View File

@ -70,8 +70,8 @@ export class SubscribeMessageHandler implements IMessageHandler, IAbortable {
)
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
debug('aborted: %o', error)
findEvents.end()
debug('subscription aborted: %o', error)
findEvents.destroy()
} else {
debug('error streaming events: %o', error)
}