From aca4422f5d4b81576d8c3cc5531cef7b7491abd2 Mon Sep 17 00:00:00 2001 From: Marcus Pousette Date: Wed, 17 May 2023 11:11:33 +0200 Subject: [PATCH] fix: restrict message sizes to 16kb (#147) If the datachannel buffer grows to larger than 256kb it will close if you are using Chrome, so if it grows too large, wait for the `bufferedamountlow` event before continuing to send data. Also split data into 16kb while sending to ensure maximum cross browser compatibility. Fixes #144 Fixes #158 --------- Co-authored-by: Alex Potsides --- package.json | 1 + src/index.ts | 25 +++++-- src/muxer.ts | 54 +++++++------- src/private-to-private/handler.ts | 13 ++-- src/private-to-private/transport.ts | 8 ++- src/private-to-public/transport.ts | 15 ++-- src/stream.ts | 51 ++++++++++++- test/stream.spec.ts | 108 ++++++++++++++++++++++++++++ 8 files changed, 227 insertions(+), 48 deletions(-) create mode 100644 test/stream.spec.ts diff --git a/package.json b/package.json index 2b9e4cb..179e199 100644 --- a/package.json +++ b/package.json @@ -158,6 +158,7 @@ "multiformats": "^11.0.2", "multihashes": "^4.0.3", "p-defer": "^4.0.0", + "p-event": "^5.0.1", "protons-runtime": "^5.0.0", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.3" diff --git a/src/index.ts b/src/index.ts index 02cd39e..b35a16c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,14 +1,31 @@ import { WebRTCTransport } from './private-to-private/transport.js' -import { WebRTCDirectTransport, type WebRTCDirectTransportComponents } from './private-to-public/transport.js' +import { WebRTCDirectTransport, type WebRTCTransportDirectInit, type WebRTCDirectTransportComponents } from './private-to-public/transport.js' import type { WebRTCTransportComponents, WebRTCTransportInit } from './private-to-private/transport.js' import type { Transport } from '@libp2p/interface-transport' -function webRTCDirect (): (components: WebRTCDirectTransportComponents) => Transport { - return (components: WebRTCDirectTransportComponents) => new WebRTCDirectTransport(components) +/** + * @param {WebRTCTransportDirectInit} init - WebRTC direct transport configuration + * @param init.dataChannel - DataChannel configurations + * @param {number} init.dataChannel.maxMessageSize - Max message size that can be sent through the DataChannel. Larger messages will be chunked into smaller messages below this size (default 16kb) + * @param {number} init.dataChannel.maxBufferedAmount - Max buffered amount a DataChannel can have (default 16mb) + * @param {number} init.dataChannel.bufferedAmountLowEventTimeout - If max buffered amount is reached, this is the max time that is waited before the buffer is cleared (default 30 seconds) + * @returns + */ +function webRTCDirect (init?: WebRTCTransportDirectInit): (components: WebRTCDirectTransportComponents) => Transport { + return (components: WebRTCDirectTransportComponents) => new WebRTCDirectTransport(components, init) } +/** + * @param {WebRTCTransportInit} init - WebRTC transport configuration + * @param {RTCConfiguration} init.rtcConfiguration - RTCConfiguration + * @param init.dataChannel - DataChannel configurations + * @param {number} init.dataChannel.maxMessageSize - Max message size that can be sent through the DataChannel. Larger messages will be chunked into smaller messages below this size (default 16kb) + * @param {number} init.dataChannel.maxBufferedAmount - Max buffered amount a DataChannel can have (default 16mb) + * @param {number} init.dataChannel.bufferedAmountLowEventTimeout - If max buffered amount is reached, this is the max time that is waited before the buffer is cleared (default 30 seconds) + * @returns + */ function webRTC (init?: WebRTCTransportInit): (components: WebRTCTransportComponents) => Transport { - return (components: WebRTCTransportComponents) => new WebRTCTransport(components, init ?? {}) + return (components: WebRTCTransportComponents) => new WebRTCTransport(components, init) } export { webRTC, webRTCDirect } diff --git a/src/muxer.ts b/src/muxer.ts index 5d96125..2487be0 100644 --- a/src/muxer.ts +++ b/src/muxer.ts @@ -1,4 +1,4 @@ -import { WebRTCStream } from './stream.js' +import { type DataChannelOpts, WebRTCStream } from './stream.js' import { nopSink, nopSource } from './util.js' import type { Stream } from '@libp2p/interface-connection' import type { CounterGroup } from '@libp2p/interface-metrics' @@ -7,39 +7,48 @@ import type { Source, Sink } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' export interface DataChannelMuxerFactoryInit { + /** + * WebRTC Peer Connection + */ peerConnection: RTCPeerConnection + + /** + * Optional metrics for this data channel muxer + */ metrics?: CounterGroup + + /** + * Data channel options + */ + dataChannelOptions?: Partial } export class DataChannelMuxerFactory implements StreamMuxerFactory { /** * WebRTC Peer Connection */ - private readonly peerConnection: RTCPeerConnection private streamBuffer: WebRTCStream[] = [] - private readonly metrics?: CounterGroup - constructor (peerConnection: RTCPeerConnection, metrics?: CounterGroup, readonly protocol = '/webrtc') { - this.peerConnection = peerConnection + constructor (readonly init: DataChannelMuxerFactoryInit, readonly protocol = '/webrtc') { // store any datachannels opened before upgrade has been completed - this.peerConnection.ondatachannel = ({ channel }) => { + this.init.peerConnection.ondatachannel = ({ channel }) => { const stream = new WebRTCStream({ channel, stat: { direction: 'inbound', timeline: { open: 0 } }, + dataChannelOptions: init.dataChannelOptions, closeCb: (_stream) => { this.streamBuffer = this.streamBuffer.filter(s => !_stream.eq(s)) } }) this.streamBuffer.push(stream) } - this.metrics = metrics } createStreamMuxer (init?: StreamMuxerInit | undefined): StreamMuxer { - return new DataChannelMuxer(this.peerConnection, this.streamBuffer, this.protocol, init, this.metrics) + return new DataChannelMuxer(this.init, this.streamBuffer, this.protocol, init) } } @@ -47,16 +56,6 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory { * A libp2p data channel stream muxer */ export class DataChannelMuxer implements StreamMuxer { - /** - * WebRTC Peer Connection - */ - private readonly peerConnection: RTCPeerConnection - - /** - * Optional metrics for this data channel muxer - */ - private readonly metrics?: CounterGroup - /** * Array of streams in the data channel */ @@ -82,24 +81,19 @@ export class DataChannelMuxer implements StreamMuxer { */ sink: Sink, Promise> = nopSink - constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly protocol: string = '/webrtc', init?: StreamMuxerInit, metrics?: CounterGroup) { + constructor (readonly dataChannelMuxer: DataChannelMuxerFactoryInit, streams: Stream[], readonly protocol: string = '/webrtc', init?: StreamMuxerInit) { /** * Initialized stream muxer */ this.init = init - /** - * WebRTC Peer Connection - */ - this.peerConnection = peerConnection - /** * Fired when a data channel has been added to the connection has been * added by the remote peer. * * {@link https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/datachannel_event} */ - this.peerConnection.ondatachannel = ({ channel }) => { + this.dataChannelMuxer.peerConnection.ondatachannel = ({ channel }) => { const stream = new WebRTCStream({ channel, stat: { @@ -108,12 +102,13 @@ export class DataChannelMuxer implements StreamMuxer { open: 0 } }, + dataChannelOptions: dataChannelMuxer.dataChannelOptions, closeCb: this.wrapStreamEnd(init?.onIncomingStream) }) this.streams.push(stream) if ((init?.onIncomingStream) != null) { - this.metrics?.increment({ incoming_stream: true }) + this.dataChannelMuxer.metrics?.increment({ incoming_stream: true }) init.onIncomingStream(stream) } } @@ -133,9 +128,9 @@ export class DataChannelMuxer implements StreamMuxer { newStream (): Stream { // The spec says the label SHOULD be an empty string: https://github.com/libp2p/specs/blob/master/webrtc/README.md#rtcdatachannel-label - const channel = this.peerConnection.createDataChannel('') + const channel = this.dataChannelMuxer.peerConnection.createDataChannel('') const closeCb = (stream: Stream): void => { - this.metrics?.increment({ stream_end: true }) + this.dataChannelMuxer.metrics?.increment({ stream_end: true }) this.init?.onStreamEnd?.(stream) } const stream = new WebRTCStream({ @@ -146,10 +141,11 @@ export class DataChannelMuxer implements StreamMuxer { open: 0 } }, + dataChannelOptions: this.dataChannelMuxer.dataChannelOptions, closeCb: this.wrapStreamEnd(closeCb) }) this.streams.push(stream) - this.metrics?.increment({ outgoing_stream: true }) + this.dataChannelMuxer.metrics?.increment({ outgoing_stream: true }) return stream } diff --git a/src/private-to-private/handler.ts b/src/private-to-private/handler.ts index cf444f7..73e258a 100644 --- a/src/private-to-private/handler.ts +++ b/src/private-to-private/handler.ts @@ -5,6 +5,7 @@ import pDefer, { type DeferredPromise } from 'p-defer' import { DataChannelMuxerFactory } from '../muxer.js' import { Message } from './pb/message.js' import { readCandidatesUntilConnected, resolveOnConnected } from './util.js' +import type { DataChannelOpts } from '../stream.js' import type { Stream } from '@libp2p/interface-connection' import type { IncomingStreamData } from '@libp2p/interface-registrar' import type { StreamMuxerFactory } from '@libp2p/interface-stream-muxer' @@ -13,14 +14,13 @@ const DEFAULT_TIMEOUT = 30 * 1000 const log = logger('libp2p:webrtc:peer') -export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration } & IncomingStreamData +export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, dataChannelOptions?: Partial } & IncomingStreamData -export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { +export async function handleIncomingStream ({ rtcConfiguration, dataChannelOptions, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { const signal = AbortSignal.timeout(DEFAULT_TIMEOUT) const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) const pc = new RTCPeerConnection(rtcConfiguration) - const muxerFactory = new DataChannelMuxerFactory(pc) - + const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions }) const connectedPromise: DeferredPromise = pDefer() const answerSentPromise: DeferredPromise = pDefer() @@ -86,13 +86,14 @@ export interface ConnectOptions { stream: Stream signal: AbortSignal rtcConfiguration?: RTCConfiguration + dataChannelOptions?: Partial } -export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { +export async function initiateConnection ({ rtcConfiguration, dataChannelOptions, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) // setup peer connection const pc = new RTCPeerConnection(rtcConfiguration) - const muxerFactory = new DataChannelMuxerFactory(pc) + const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, dataChannelOptions }) const connectedPromise: DeferredPromise = pDefer() resolveOnConnected(pc, connectedPromise) diff --git a/src/private-to-private/transport.ts b/src/private-to-private/transport.ts index fb2e204..69d8554 100644 --- a/src/private-to-private/transport.ts +++ b/src/private-to-private/transport.ts @@ -7,6 +7,7 @@ import { codes } from '../error.js' import { WebRTCMultiaddrConnection } from '../maconn.js' import { initiateConnection, handleIncomingStream } from './handler.js' import { WebRTCPeerListener } from './listener.js' +import type { DataChannelOpts } from '../stream.js' import type { Connection } from '@libp2p/interface-connection' import type { PeerId } from '@libp2p/interface-peer-id' import type { IncomingStreamData, Registrar } from '@libp2p/interface-registrar' @@ -21,6 +22,7 @@ const WEBRTC_CODE = protocols('webrtc').code export interface WebRTCTransportInit { rtcConfiguration?: RTCConfiguration + dataChannel?: Partial } export interface WebRTCTransportComponents { @@ -35,7 +37,7 @@ export class WebRTCTransport implements Transport, Startable { constructor ( private readonly components: WebRTCTransportComponents, - private readonly init: WebRTCTransportInit + private readonly init: WebRTCTransportInit = {} ) { } @@ -123,6 +125,7 @@ export class WebRTCTransport implements Transport, Startable { const { pc, muxerFactory, remoteAddress } = await initiateConnection({ stream: signalingStream, rtcConfiguration: this.init.rtcConfiguration, + dataChannelOptions: this.init.dataChannel, signal: options.signal }) @@ -154,7 +157,8 @@ export class WebRTCTransport implements Transport, Startable { const { pc, muxerFactory, remoteAddress } = await handleIncomingStream({ rtcConfiguration: this.init.rtcConfiguration, connection, - stream + stream, + dataChannelOptions: this.init.dataChannel }) await this.components.upgrader.upgradeInbound(new WebRTCMultiaddrConnection({ diff --git a/src/private-to-public/transport.ts b/src/private-to-public/transport.ts index 30d658d..8499024 100644 --- a/src/private-to-public/transport.ts +++ b/src/private-to-public/transport.ts @@ -9,7 +9,7 @@ import { fromString as uint8arrayFromString } from 'uint8arrays/from-string' import { dataChannelError, inappropriateMultiaddr, unimplemented, invalidArgument } from '../error.js' import { WebRTCMultiaddrConnection } from '../maconn.js' import { DataChannelMuxerFactory } from '../muxer.js' -import { WebRTCStream } from '../stream.js' +import { type DataChannelOpts, WebRTCStream } from '../stream.js' import { isFirefox } from '../util.js' import * as sdp from './sdp.js' import { genUfrag } from './util.js' @@ -52,12 +52,17 @@ export interface WebRTCMetrics { dialerEvents: CounterGroup } +export interface WebRTCTransportDirectInit { + dataChannel?: Partial +} + export class WebRTCDirectTransport implements Transport { private readonly metrics?: WebRTCMetrics private readonly components: WebRTCDirectTransportComponents - - constructor (components: WebRTCDirectTransportComponents) { + private readonly init: WebRTCTransportDirectInit + constructor (components: WebRTCDirectTransportComponents, init: WebRTCTransportDirectInit = {}) { this.components = components + this.init = init if (components.metrics != null) { this.metrics = { dialerEvents: components.metrics.registerCounterGroup('libp2p_webrtc_dialer_events_total', { @@ -185,7 +190,7 @@ export class WebRTCDirectTransport implements Transport { // we pass in undefined for these parameters. const noise = Noise({ prologueBytes: fingerprintsPrologue })() - const wrappedChannel = new WebRTCStream({ channel: handshakeDataChannel, stat: { direction: 'inbound', timeline: { open: 1 } } }) + const wrappedChannel = new WebRTCStream({ channel: handshakeDataChannel, stat: { direction: 'inbound', timeline: { open: 1 } }, dataChannelOptions: this.init.dataChannel }) const wrappedDuplex = { ...wrappedChannel, sink: wrappedChannel.sink.bind(wrappedChannel), @@ -231,7 +236,7 @@ export class WebRTCDirectTransport implements Transport { // Track opened peer connection this.metrics?.dialerEvents.increment({ peer_connection: true }) - const muxerFactory = new DataChannelMuxerFactory(peerConnection, this.metrics?.dialerEvents) + const muxerFactory = new DataChannelMuxerFactory({ peerConnection, metrics: this.metrics?.dialerEvents, dataChannelOptions: this.init.dataChannel }) // For outbound connections, the remote is expected to start the noise handshake. // Therefore, we need to secure an inbound noise connection from the remote. diff --git a/src/stream.ts b/src/stream.ts index cf35eed..340c832 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -4,6 +4,7 @@ import merge from 'it-merge' import { pipe } from 'it-pipe' import { pushable } from 'it-pushable' import defer, { type DeferredPromise } from 'p-defer' +import { pEvent } from 'p-event' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from './pb/message.js' import type { Stream, StreamStat, Direction } from '@libp2p/interface-connection' @@ -24,6 +25,12 @@ export function defaultStat (dir: Direction): StreamStat { } } +export interface DataChannelOpts { + maxMessageSize: number + maxBufferedAmount: number + bufferedAmountLowEventTimeout: number +} + interface StreamInitOpts { /** * The network channel used for bidirectional peer-to-peer transfers of @@ -47,6 +54,11 @@ interface StreamInitOpts { * Callback to invoke when the stream is closed. */ closeCb?: (stream: WebRTCStream) => void + + /** + * Data channel options + */ + dataChannelOptions?: Partial } /* @@ -151,6 +163,15 @@ class StreamState { } } +// Max message size that can be sent to the DataChannel +const MAX_MESSAGE_SIZE = 16 * 1024 + +// How much can be buffered to the DataChannel at once +const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024 + +// How long time we wait for the 'bufferedamountlow' event to be emitted +const BUFFERED_AMOUNT_LOW_TIMEOUT = 30 * 1000 + export class WebRTCStream implements Stream { /** * Unique identifier for a stream @@ -177,6 +198,11 @@ export class WebRTCStream implements Stream { */ streamState = new StreamState() + /** + * Data channel options + */ + dataChannelOptions: DataChannelOpts + /** * Read unwrapped protobuf data from the underlying datachannel. * _src is exposed to the user via the `source` getter to . @@ -214,8 +240,14 @@ export class WebRTCStream implements Stream { this.channel = opts.channel this.channel.binaryType = 'arraybuffer' this.id = this.channel.label - this.stat = opts.stat + this.dataChannelOptions = { + bufferedAmountLowEventTimeout: opts.dataChannelOptions?.bufferedAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT, + maxBufferedAmount: opts.dataChannelOptions?.maxBufferedAmount ?? MAX_BUFFERED_AMOUNT, + maxMessageSize: opts.dataChannelOptions?.maxMessageSize ?? MAX_MESSAGE_SIZE + } + this.closeCb = opts.closeCb + switch (this.channel.readyState) { case 'open': this.opened.resolve() @@ -313,10 +345,25 @@ export class WebRTCStream implements Stream { if (this.streamState.isWriteClosed()) { return } + + if (this.channel.bufferedAmount > this.dataChannelOptions.maxBufferedAmount) { + await pEvent(this.channel, 'bufferedamountlow', { timeout: this.dataChannelOptions.bufferedAmountLowEventTimeout }).catch((e) => { + this.close() + throw new Error('Timed out waiting for DataChannel buffer to clear') + }) + } + const msgbuf = Message.encode({ message: buf.subarray() }) const sendbuf = lengthPrefixed.encode.single(msgbuf) - this.channel.send(sendbuf.subarray()) + while (sendbuf.length > 0) { + if (sendbuf.length <= this.dataChannelOptions.maxMessageSize) { + this.channel.send(sendbuf.subarray()) + break + } + this.channel.send(sendbuf.subarray(0, this.dataChannelOptions.maxMessageSize)) + sendbuf.consume(this.dataChannelOptions.maxMessageSize) + } } } diff --git a/test/stream.spec.ts b/test/stream.spec.ts new file mode 100644 index 0000000..2bfaa0c --- /dev/null +++ b/test/stream.spec.ts @@ -0,0 +1,108 @@ +/* eslint-disable @typescript-eslint/consistent-type-assertions */ + +import { expect } from 'aegir/chai' +import * as lengthPrefixed from 'it-length-prefixed' +import { pushable } from 'it-pushable' +import { Message } from '../src/pb/message' +import * as underTest from '../src/stream' + +const mockDataChannel = (opts: { send: (bytes: Uint8Array) => void, bufferedAmount?: number }): RTCDataChannel => { + return { + readyState: 'open', + close: () => {}, + addEventListener: (_type: string, _listener: () => void) => {}, + removeEventListener: (_type: string, _listener: () => void) => {}, + ...opts + } as RTCDataChannel +} + +const MAX_MESSAGE_SIZE = 16 * 1024 + +describe('Max message size', () => { + it(`sends messages smaller or equal to ${MAX_MESSAGE_SIZE} bytes in one`, async () => { + const sent: Uint8Array[] = [] + const data = new Uint8Array(MAX_MESSAGE_SIZE - 5) + const p = pushable() + + // Make sure that the data that ought to be sent will result in a message with exactly MAX_MESSAGE_SIZE + const messageLengthEncoded = lengthPrefixed.encode.single(Message.encode({ message: data })).subarray() + expect(messageLengthEncoded.length).eq(MAX_MESSAGE_SIZE) + const webrtcStream = new underTest.WebRTCStream({ + channel: mockDataChannel({ + send: (bytes) => { + sent.push(bytes) + if (p.readableLength === 0) { + webrtcStream.close() + } + } + }), + stat: underTest.defaultStat('outbound') + }) + + p.push(data) + p.end() + await webrtcStream.sink(p) + expect(sent).to.deep.equals([messageLengthEncoded]) + }) + + it(`sends messages greater than ${MAX_MESSAGE_SIZE} bytes in parts`, async () => { + const sent: Uint8Array[] = [] + const data = new Uint8Array(MAX_MESSAGE_SIZE - 4) + const p = pushable() + + // Make sure that the data that ought to be sent will result in a message with exactly MAX_MESSAGE_SIZE + 1 + const messageLengthEncoded = lengthPrefixed.encode.single(Message.encode({ message: data })).subarray() + expect(messageLengthEncoded.length).eq(MAX_MESSAGE_SIZE + 1) + + const webrtcStream = new underTest.WebRTCStream({ + channel: mockDataChannel({ + send: (bytes) => { + sent.push(bytes) + if (p.readableLength === 0) { + webrtcStream.close() + } + } + }), + stat: underTest.defaultStat('outbound') + }) + + p.push(data) + p.end() + await webrtcStream.sink(p) + + // Message is sent in two parts + expect(sent).to.deep.equals([messageLengthEncoded.subarray(0, messageLengthEncoded.length - 1), messageLengthEncoded.subarray(messageLengthEncoded.length - 1)]) + }) + + it('closes the stream if bufferamountlow timeout', async () => { + const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024 + 1 + const timeout = 2000 + let closed = false + const webrtcStream = new underTest.WebRTCStream({ + dataChannelOptions: { bufferedAmountLowEventTimeout: timeout }, + channel: mockDataChannel({ + send: () => { + throw new Error('Expected to not send') + }, + bufferedAmount: MAX_BUFFERED_AMOUNT + }), + stat: underTest.defaultStat('outbound'), + closeCb: () => { + closed = true + } + }) + + const p = pushable() + p.push(new Uint8Array(1)) + p.end() + + const t0 = Date.now() + + await expect(webrtcStream.sink(p)).to.eventually.be.rejected + .with.property('message', 'Timed out waiting for DataChannel buffer to clear') + const t1 = Date.now() + expect(t1 - t0).greaterThan(timeout) + expect(t1 - t0).lessThan(timeout + 1000) // Some upper bound + expect(closed).true() + }) +})