Skip to content

Commit

Permalink
refactor: move code duplication to shared handler
Browse files Browse the repository at this point in the history
Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>
  • Loading branch information
ckohen and Jiralite committed Mar 21, 2023
1 parent 03e438b commit 7a2da4c
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 160 deletions.
139 changes: 139 additions & 0 deletions packages/rest/__tests__/BurstHandler.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* eslint-disable id-length */
/* eslint-disable promise/prefer-await-to-then */
import { performance } from 'node:perf_hooks';
import { MockAgent, setGlobalDispatcher } from 'undici';
import type { Interceptable, MockInterceptor } from 'undici/types/mock-interceptor';
import { beforeEach, afterEach, test, expect, vitest } from 'vitest';
import { DiscordAPIError, HTTPError, RateLimitError, REST, BurstHandlerMajorIdKey } from '../src/index.js';
import { BurstHandler } from '../src/lib/handlers/BurstHandler.js';
import { genPath } from './util.js';

const callbackKey = `Global(POST:/interactions/:id/:token/callback):${BurstHandlerMajorIdKey}`;
const callbackPath = new RegExp(genPath('/interactions/[0-9]{17,19}/.+/callback'));

const api = new REST();

let mockAgent: MockAgent;
let mockPool: Interceptable;

beforeEach(() => {
mockAgent = new MockAgent();
mockAgent.disableNetConnect();
setGlobalDispatcher(mockAgent);

mockPool = mockAgent.get('https://discord.com');
api.setAgent(mockAgent);
});

afterEach(async () => {
await mockAgent.close();
});

// @discordjs/rest uses the `content-type` header to detect whether to parse
// the response as JSON or as an ArrayBuffer.
const responseOptions: MockInterceptor.MockResponseOptions = {
headers: {
'content-type': 'application/json',
},
};

test('Interaction callback creates burst handler', async () => {
mockPool.intercept({ path: callbackPath, method: 'POST' }).reply(200);

expect(api.requestManager.handlers.get(callbackKey)).toBe(undefined);
expect(
await api.post('/interactions/1234567890123456789/totallyarealtoken/callback', {
auth: false,
body: { type: 4, data: { content: 'Reply' } },
}),
).toBeInstanceOf(Uint8Array);
expect(api.requestManager.handlers.get(callbackKey)).toBeInstanceOf(BurstHandler);
});

test('Requests are handled in bursts', async () => {
mockPool.intercept({ path: callbackPath, method: 'POST' }).reply(200).delay(100).times(3);

// Return the current time on these results as their response does not indicate anything
const [a, b, c] = await Promise.all([
api
.post('/interactions/1234567890123456789/totallyarealtoken/callback', {
auth: false,
body: { type: 4, data: { content: 'Reply1' } },
})
.then(() => performance.now()),
api
.post('/interactions/2345678901234567890/anotherveryrealtoken/callback', {
auth: false,
body: { type: 4, data: { content: 'Reply2' } },
})
.then(() => performance.now()),
api
.post('/interactions/3456789012345678901/nowaytheresanotherone/callback', {
auth: false,
body: { type: 4, data: { content: 'Reply3' } },
})
.then(() => performance.now()),
]);

expect(b - a).toBeLessThan(10);
expect(c - a).toBeLessThan(10);
});

test('Handle 404', async () => {
mockPool
.intercept({ path: callbackPath, method: 'POST' })
.reply(404, { message: 'Unknown interaction', code: 10_062 }, responseOptions);

const promise = api.post('/interactions/1234567890123456788/definitelynotarealinteraction/callback', {
auth: false,
body: { type: 4, data: { content: 'Malicious' } },
});
await expect(promise).rejects.toThrowError('Unknown interaction');
await expect(promise).rejects.toBeInstanceOf(DiscordAPIError);
});

let unexpected429 = true;
test('Handle unexpected 429', async () => {
mockPool
.intercept({
path: callbackPath,
method: 'POST',
})
.reply(() => {
if (unexpected429) {
unexpected429 = false;
return {
statusCode: 429,
data: '',
responseOptions: {
headers: {
'retry-after': '1',
via: '1.1 google',
},
},
};
}

return {
statusCode: 200,
data: { test: true },
responseOptions,
};
})
.times(2);

const previous = performance.now();
let firstResolvedTime: number;
const unexpectedLimit = api
.post('/interactions/1234567890123456789/totallyarealtoken/callback', {
auth: false,
body: { type: 4, data: { content: 'Reply' } },
})
.then((res) => {
firstResolvedTime = performance.now();
return res;
});

expect(await unexpectedLimit).toStrictEqual({ test: true });
expect(performance.now()).toBeGreaterThanOrEqual(previous + 1_000);
});
6 changes: 4 additions & 2 deletions packages/rest/__tests__/REST.test.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import { Buffer } from 'node:buffer';
import { Buffer, File as NativeFile } from 'node:buffer';
import { URLSearchParams } from 'node:url';
import { DiscordSnowflake } from '@sapphire/snowflake';
import type { Snowflake } from 'discord-api-types/v10';
import { Routes } from 'discord-api-types/v10';
import type { FormData } from 'undici';
import { File, MockAgent, setGlobalDispatcher } from 'undici';
import { File as UndiciFile, MockAgent, setGlobalDispatcher } from 'undici';
import type { Interceptable, MockInterceptor } from 'undici/types/mock-interceptor';
import { beforeEach, afterEach, test, expect } from 'vitest';
import { REST } from '../src/index.js';
import { genPath } from './util.js';

const File = NativeFile ?? UndiciFile;

const newSnowflake: Snowflake = DiscordSnowflake.generate().toString();

const api = new REST().setToken('A-Very-Fake-Token');
Expand Down
34 changes: 0 additions & 34 deletions packages/rest/src/lib/RequestManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,6 @@ export interface RequestManager {
(<S extends string | symbol>(event?: Exclude<S, keyof RestEvents>) => this);
}

/**
* Invalid request limiting is done on a per-IP basis, not a per-token basis.
* The best we can do is track invalid counts process-wide (on the theory that
* users could have multiple bots run from one process) rather than per-bot.
* Therefore, store these at file scope here rather than in the client's
* RESTManager object.
*/
let invalidCount = 0;
let invalidCountResetTime: number | null = null;

/**
* Represents the class that manages handlers for endpoints
*/
Expand Down Expand Up @@ -511,30 +501,6 @@ export class RequestManager extends EventEmitter {
clearInterval(this.handlerTimer);
}

/**
* Increment the invalid request count and emit warning if necessary
*
* @internal
*/
public incrementInvalidCount() {
if (!invalidCountResetTime || invalidCountResetTime < Date.now()) {
invalidCountResetTime = Date.now() + 1_000 * 60 * 10;
invalidCount = 0;
}

invalidCount++;

const emitInvalid =
this.options.invalidRequestWarningInterval > 0 && invalidCount % this.options.invalidRequestWarningInterval === 0;
if (emitInvalid) {
// Let library users know periodically about invalid requests
this.emit(RESTEvents.InvalidRequestWarning, {
count: invalidCount,
remainingTime: invalidCountResetTime - Date.now(),
});
}
}

/**
* Generates route data for an endpoint:method
*
Expand Down
76 changes: 19 additions & 57 deletions packages/rest/src/lib/handlers/BurstHandler.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import { setTimeout, clearTimeout } from 'node:timers';
import { setTimeout as sleep } from 'node:timers/promises';
import { request, type Dispatcher } from 'undici';
import type { Dispatcher } from 'undici';
import type { RequestOptions } from '../REST.js';
import type { HandlerRequestData, RequestManager, RouteData } from '../RequestManager.js';
import { DiscordAPIError, type DiscordErrorData, type OAuthErrorData } from '../errors/DiscordAPIError.js';
import { HTTPError } from '../errors/HTTPError.js';
import { RESTEvents } from '../utils/constants.js';
import { onRateLimit, parseHeader, parseResponse, shouldRetry } from '../utils/utils.js';
import type { IHandler, PolyFillAbortSignal } from './IHandler.js';
import { onRateLimit, parseHeader } from '../utils/utils.js';
import type { IHandler } from './IHandler.js';
import { handleErrors, incrementInvalidCount, makeNetworkRequest } from './Shared.js';

/**
* The structure used to handle burst requests for a given bucket.
* Burst requests have no ratelimit handling but allow for pre- and post-processing
* of data in the same manner as sequentially queued requests.
*
* @remarks
* This queue may still emit a rate limit error if an unexpected 429 is hit</info>
* This queue may still emit a rate limit error if an unexpected 429 is hit
*/
export class BurstHandler implements IHandler {
/**
Expand Down Expand Up @@ -64,10 +62,10 @@ export class BurstHandler implements IHandler {
}

/**
* The method that actually makes the request to the api, and updates info about the bucket accordingly
* The method that actually makes the request to the API, and updates info about the bucket accordingly
*
* @param routeId - The generalized api route with literal ids for major parameters
* @param url - The fully resolved url to make the request to
* @param routeId - The generalized API route with literal ids for major parameters
* @param url - The fully resolved URL to make the request to
* @param options - The fetch options needed to make the request
* @param requestData - Extra data from the user's request needed for errors and additional processing
* @param retries - The number of retries this request has already attempted (recursion)
Expand All @@ -81,32 +79,12 @@ export class BurstHandler implements IHandler {
): Promise<Dispatcher.ResponseData> {
const method = options.method ?? 'get';

const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), this.manager.options.timeout).unref();
if (requestData.signal) {
// The type polyfill is required because Node.js's types are incomplete.
const signal = requestData.signal as PolyFillAbortSignal;
// If the user signal was aborted, abort the controller, else abort the local signal.
// The reason why we don't re-use the user's signal, is because users may use the same signal for multiple
// requests, and we do not want to cause unexpected side-effects.
if (signal.aborted) controller.abort();
else signal.addEventListener('abort', () => controller.abort());
}
const res = await makeNetworkRequest(this.manager, routeId, url, options, requestData, retries);

let res: Dispatcher.ResponseData;
try {
res = await request(url, { ...options, signal: controller.signal });
} catch (error: unknown) {
if (!(error instanceof Error)) throw error;
// Retry the specified number of times if needed
if (shouldRetry(error) && retries !== this.manager.options.retries) {
// eslint-disable-next-line no-param-reassign
return await this.runRequest(routeId, url, options, requestData, ++retries);
}

throw error;
} finally {
clearTimeout(timeout);
// Retry requested
if (res === null) {
// eslint-disable-next-line no-param-reassign
return this.runRequest(routeId, url, options, requestData, ++retries);
}

const status = res.statusCode;
Expand All @@ -118,7 +96,7 @@ export class BurstHandler implements IHandler {

// Count the invalid requests
if (status === 401 || status === 403 || status === 429) {
this.manager.incrementInvalidCount();
incrementInvalidCount(this.manager);
}

if (status >= 200 && status < 300) {
Expand Down Expand Up @@ -151,35 +129,19 @@ export class BurstHandler implements IHandler {
].join('\n'),
);

// We are bypassing all other limits, but an encountered limit should be respected (it's probably a non-punished ratelimit anyways)
// We are bypassing all other limits, but an encountered limit should be respected (it's probably a non-punished rate limit anyways)
await sleep(retryAfter);

// Since this is not a server side issue, the next request should pass, so we don't bump the retries counter
return this.runRequest(routeId, url, options, requestData, retries);
} else if (status >= 500 && status < 600) {
// Retry the specified number of times for possible server side issues
if (retries !== this.manager.options.retries) {
} else {
const handled = await handleErrors(this.manager, res, method, url, requestData, retries);
if (handled === null) {
// eslint-disable-next-line no-param-reassign
return this.runRequest(routeId, url, options, requestData, ++retries);
}

// We are out of retries, throw an error
throw new HTTPError(status, method, url, requestData);
} else {
// Handle possible malformed requests
if (status >= 400 && status < 500) {
// If we receive this status code, it means the token we had is no longer valid.
if (status === 401 && requestData.auth) {
this.manager.setToken(null!);
}

// The request will not succeed for some reason, parse the error returned from the api
const data = (await parseResponse(res)) as DiscordErrorData | OAuthErrorData;
// throw the API error
throw new DiscordAPIError(data, 'code' in data ? data.code : data.error, status, method, url, requestData);
}

return res;
return handled;
}
}
}
Loading

0 comments on commit 7a2da4c

Please sign in to comment.