Skip to content

Commit

Permalink
feat(WebSocket): allow .connect() amidst a .close()
Browse files Browse the repository at this point in the history
  • Loading branch information
kettanaito committed Mar 24, 2024
1 parent 0737094 commit 74b07c5
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 18 deletions.
16 changes: 8 additions & 8 deletions src/interceptors/WebSocket/WebSocketServerConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,36 +97,36 @@ export class WebSocketServerConnection {
*/
public connect(): void {
invariant(
this.readyState === -1,
!this.realWebSocket || this.realWebSocket.readyState !== WebSocket.OPEN,
'Failed to call "connect()" on the original WebSocket instance: the connection already open'
)

const ws = this.createConnection()
const realWebSocket = this.createConnection()

// Inherit the binary type from the mock WebSocket client.
ws.binaryType = this.socket.binaryType
realWebSocket.binaryType = this.socket.binaryType

// Close the original connection when the (mock)
// client closes, regardless of the reason.
this.socket.addEventListener(
'close',
(event) => {
ws.close(event.code, event.reason)
realWebSocket.close(event.code, event.reason)
},
{ once: true }
)

ws.addEventListener('message', (event) => {
realWebSocket.addEventListener('message', (event) => {
this.transport.onIncoming(event)
})

// Forward server errors to the WebSocket client as-is.
// We may consider exposing them to the interceptor in the future.
ws.addEventListener('error', () => {
realWebSocket.addEventListener('error', () => {
this.socket.dispatchEvent(bindEvent(this.socket, new Event('error')))
})

this.realWebSocket = ws
this.realWebSocket = realWebSocket
}

/**
Expand Down Expand Up @@ -171,7 +171,7 @@ export class WebSocketServerConnection {

invariant(
realWebSocket,
'Failed to call "server.send()" for "%s": the connection is not open. Did you forget to call "await server.connect()"?',
'Failed to call "server.send()" for "%s": the connection is not open. Did you forget to call "server.connect()"?',
this.socket.url
)

Expand Down
75 changes: 65 additions & 10 deletions test/modules/WebSocket/compliance/websocket.server.close.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
* @vitest-environment node-with-websocket
*/
import { DeferredPromise } from '@open-draft/deferred-promise'
import { RawData } from 'engine.io-parser'
import { vi, it, expect, beforeAll, afterEach, afterAll } from 'vitest'
import { WebSocketServer } from 'ws'
import { WebSocketServer, Data } from 'ws'
import {
WebSocketInterceptor,
WebSocketServerConnection,
} from '../../../../src/interceptors/WebSocket/index'
import { getWsUrl } from '../utils/getWsUrl'
import { waitForWebSocketEvent } from '../utils/waitForWebSocketEvent'
import { waitForNextTick } from '../utils/waitForNextTick'

const interceptor = new WebSocketInterceptor()
Expand All @@ -24,6 +24,7 @@ beforeAll(() => {
})

afterEach(() => {
interceptor.removeAllListeners()
wsServer.clients.forEach((client) => client.close())
})

Expand All @@ -47,13 +48,12 @@ it('throws if closing the unconnected server', async () => {
})

it('closes the actual server connection when called "server.close()"', async () => {
const clientOpenPromise = new DeferredPromise<void>()
const serverCallback = vi.fn<[number]>()
const clientMessageListener = vi.fn<[RawData]>()
const originalClientMessageListener = vi.fn<[Data]>()

wsServer.on('connection', (client) => {
client.addEventListener('message', (event) => {
clientMessageListener(event.data)
originalClientMessageListener(event.data)
})
})

Expand All @@ -80,14 +80,14 @@ it('closes the actual server connection when called "server.close()"', async ()
})

const ws = new WebSocket(getWsUrl(wsServer))
ws.onopen = () => clientOpenPromise.resolve()
ws.onerror = () => clientOpenPromise.reject()
await clientOpenPromise
await waitForWebSocketEvent('open', ws)

// Must forward the client messages to the original server.
ws.send('hello from client')
await vi.waitFor(() => {
expect(clientMessageListener).toHaveBeenCalledWith('hello from client')
expect(originalClientMessageListener).toHaveBeenCalledWith(
'hello from client'
)
})

// Must close the server connection once "server.close()" is called.
Expand All @@ -100,5 +100,60 @@ it('closes the actual server connection when called "server.close()"', async ()
// after the connection has been closed.
ws.send('another hello')
await waitForNextTick()
expect(clientMessageListener).not.toHaveBeenCalledWith('another hello')
expect(originalClientMessageListener).not.toHaveBeenCalledWith(
'another hello'
)
})

it('resumes forwarding client events to the server once it is reconnected', async () => {
const originalClientMessageListener = vi.fn<[Data]>()
wsServer.on('connection', (client) => {
client.addEventListener('message', (event) => {
originalClientMessageListener(event.data)
})
})

interceptor.once('connection', ({ client, server }) => {
server.connect()

client.addEventListener('message', (event) => {
if (event.data === 'server/close') {
server.close()
}

if (event.data === 'server/reconnect') {
server.connect()
}

server.send(event.data)
})
})

const ws = new WebSocket(getWsUrl(wsServer))
await waitForWebSocketEvent('open', ws)

ws.send('first hello')
await vi.waitFor(() => {
expect(originalClientMessageListener).toHaveBeenLastCalledWith(
'first hello'
)
})

ws.send('server/close')
await waitForNextTick()
ws.send('second hello')
await vi.waitFor(() => {
expect(originalClientMessageListener).not.toHaveBeenCalledWith(
'second hello'
)
})

ws.send('server/reconnect')
await waitForNextTick()
ws.send('third hello')
await vi.waitFor(() => {
expect(originalClientMessageListener).toHaveBeenLastCalledWith(
'third hello'
)
})
})
14 changes: 14 additions & 0 deletions test/modules/WebSocket/utils/waitForWebSocketEvent.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { DeferredPromise } from '@open-draft/deferred-promise'

/**
* Returns a Promise that resolves when the given WebSocket
* instance emits the said event.
*/
export function waitForWebSocketEvent<Type extends keyof WebSocketEventMap>(
type: Type,
ws: WebSocket
) {
const eventPromise = new DeferredPromise<void>()
ws.addEventListener(type, () => eventPromise.resolve(), { once: true })
return eventPromise
}

0 comments on commit 74b07c5

Please sign in to comment.