Skip to content

Commit

Permalink
fix: request abort signal
Browse files Browse the repository at this point in the history
  • Loading branch information
ronag committed May 7, 2024
1 parent 08363f0 commit fade96c
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 8 deletions.
29 changes: 21 additions & 8 deletions lib/api/api-request.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ const { InvalidArgumentError } = require('../core/errors')
const util = require('../core/util')
const { getResolveErrorBodyCallback } = require('./util')
const { AsyncResource } = require('node:async_hooks')
const { addSignal, removeSignal } = require('./abort-signal')

class RequestHandler extends AsyncResource {
constructor (opts, callback) {
Expand Down Expand Up @@ -56,19 +55,28 @@ class RequestHandler extends AsyncResource {
this.onInfo = onInfo || null
this.throwOnError = throwOnError
this.highWaterMark = highWaterMark
this.signal = signal

if (util.isStream(body)) {
body.on('error', (err) => {
this.onError(err)
})
}

addSignal(this, signal)
if (this.signal) {
this.removeAbortListener = util.addAbortListener(this.signal, () => {
if (this.res) {
this.res.destroy(this.signal.reason)
} else if (this.abort) {
this.abort(this.signal.reason)
}
})
}
}

onConnect (abort, context) {
if (this.reason) {
abort(this.reason)
if (this.signal && this.signal.aborted) {
abort(this.signal.reason)
return
}

Expand All @@ -95,6 +103,13 @@ class RequestHandler extends AsyncResource {
const contentLength = parsedHeaders['content-length']
const body = new Readable({ resume, abort, contentType, contentLength, highWaterMark })

if (this.removeAbortListener) {
// TODO (fix): 'close' is sufficient but breaks tests.
body
.on('end', this.removeAbortListener)
.on('error', this.removeAbortListener)
}

this.callback = null
this.res = body
if (callback !== null) {
Expand Down Expand Up @@ -123,8 +138,6 @@ class RequestHandler extends AsyncResource {
onComplete (trailers) {
const { res } = this

removeSignal(this)

util.parseHeaders(trailers, this.trailers)

res.push(null)
Expand All @@ -133,8 +146,6 @@ class RequestHandler extends AsyncResource {
onError (err) {
const { res, callback, body, opaque } = this

removeSignal(this)

if (callback) {
// TODO: Does this need queueMicrotask?
this.callback = null
Expand All @@ -149,6 +160,8 @@ class RequestHandler extends AsyncResource {
queueMicrotask(() => {
util.destroy(res, err)
})
} else if (this.removeAbortListener) {
this.removeAbortListener()
}

if (body) {
Expand Down
76 changes: 76 additions & 0 deletions test/request-signal.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
'use strict'

const { createServer } = require('node:http')
const { test, after } = require('node:test')
const { tspl } = require('@matteo.collina/tspl')
const { request } = require('..')

test('pre abort signal w/ reason', async (t) => {
t = tspl(t, { plan: 1 })

const server = createServer((req, res) => {
res.end('asd')
})
after(() => server.close())

server.listen(0, async () => {
const ac = new AbortController()
const _err = new Error()
ac.abort(_err)
try {
await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal })
} catch (err) {
t.equal(err, _err)
}
})
await t.completed
})

test('post abort signal', async (t) => {
t = tspl(t, { plan: 1 })

const server = createServer((req, res) => {
res.end('asd')
})
after(() => server.close())

server.listen(0, async () => {
const ac = new AbortController()
const ures = await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal })
ac.abort()
try {
/* eslint-disable-next-line no-unused-vars */
for await (const chunk of ures.body) {
// Do nothing...
}
} catch (err) {
t.equal(err.name, 'AbortError')
}
})
await t.completed
})

test('post abort signal w/ reason', async (t) => {
t = tspl(t, { plan: 1 })

const server = createServer((req, res) => {
res.end('asd')
})
after(() => server.close())

server.listen(0, async () => {
const ac = new AbortController()
const _err = new Error()
const ures = await request(`http://0.0.0.0:${server.address().port}`, { signal: ac.signal })
ac.abort(_err)
try {
/* eslint-disable-next-line no-unused-vars */
for await (const chunk of ures.body) {
// Do nothing...
}
} catch (err) {
t.equal(err, _err)
}
})
await t.completed
})

0 comments on commit fade96c

Please sign in to comment.