Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] rewrite backend resolve to allow multiple EPs #19735

Merged
merged 6 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 90 additions & 31 deletions js/common/lib/backend-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {Backend} from './backend.js';
import {InferenceSession} from './inference-session.js';

interface BackendInfo {
backend: Backend;
Expand All @@ -10,6 +11,7 @@ interface BackendInfo {
initPromise?: Promise<void>;
initialized?: boolean;
aborted?: boolean;
error?: string;
}

const backends: Map<string, BackendInfo> = new Map();
Expand Down Expand Up @@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number
};

/**
* Resolve backend by specified hints.
* Try to resolve and initialize a backend.
*
* @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list.
* @returns a promise that resolves to the backend.
* @param backendName - the name of the backend.
* @returns the backend instance if resolved and initialized successfully, or an error message if failed.
*/
const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backend|string> => {
// Look up the registration record for this name; an unregistered name is reported as an error string, not thrown.
const backendInfo = backends.get(backendName);
if (!backendInfo) {
return 'backend not found.';
}

if (backendInfo.initialized) {
// A previous call already finished initialization successfully: reuse the instance.
return backendInfo.backend;
} else if (backendInfo.aborted) {
// A previous initialization attempt failed: return the recorded error message without retrying.
return backendInfo.error!;
} else {
// A pending `initPromise` means another call has already started init() and it has not settled yet.
const isInitializing = !!backendInfo.initPromise;
try {
if (!isInitializing) {
// This call starts the initialization; the promise is stored so that concurrent calls await the same one.
backendInfo.initPromise = backendInfo.backend.init(backendName);
}
await backendInfo.initPromise;
backendInfo.initialized = true;
return backendInfo.backend;
} catch (e) {
// Only the call that started init() records the failure (stringified) and marks the backend as aborted.
// Calls that merely awaited the shared promise return whatever error has been recorded.
// NOTE(review): this presumably relies on the initiating call's catch running before the waiters' - confirm.
if (!isInitializing) {
backendInfo.error = `${e}`;
backendInfo.aborted = true;
}
return backendInfo.error!;
} finally {
// Clear the in-flight marker regardless of whether initialization succeeded or failed.
delete backendInfo.initPromise;
}
}
};

/**
* Resolve execution providers from the specific session options.
*
* @param options - the session options object.
* @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with
* filtered EP list.
*
* @ignore
*/
export const resolveBackend = async(backendHints: readonly string[]): Promise<Backend> => {
const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
const errors = [];
for (const backendName of backendNames) {
const backendInfo = backends.get(backendName);
if (backendInfo) {
if (backendInfo.initialized) {
return backendInfo.backend;
} else if (backendInfo.aborted) {
continue; // current backend is unavailable; try next
}
export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions):
Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => {
// extract backend hints from session options
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;

const isInitializing = !!backendInfo.initPromise;
try {
if (!isInitializing) {
backendInfo.initPromise = backendInfo.backend.init(backendName);
// try to resolve and initialize all requested backends
let backend: Backend|undefined;
const errors = [];
const availableBackendNames = new Set<string>();
for (const backendName of backendNames) {
const resolveResult = await tryResolveAndInitializeBackend(backendName);
if (typeof resolveResult === 'string') {
errors.push({name: backendName, err: resolveResult});
} else {
if (!backend) {
backend = resolveResult;
}
if (backend === resolveResult) {
availableBackendNames.add(backendName);
}
}
await backendInfo.initPromise;
backendInfo.initialized = true;
return backendInfo.backend;
} catch (e) {
if (!isInitializing) {
errors.push({name: backendName, err: e});
}

// if no backend is available, throw error.
if (!backend) {
throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
}

// for each explicitly requested backend, if it's not available, output warning message.
for (const {name, err} of errors) {
if (backendHints.includes(name)) {
// eslint-disable-next-line no-console
console.warn(`removing requested execution provider "${
name}" from session options because it is not available: ${err}`);
}
backendInfo.aborted = true;
} finally {
delete backendInfo.initPromise;
}
}
}

throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
};
const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name));

return [
backend, new Proxy(options, {
get: (target, prop) => {
if (prop === 'executionProviders') {
return filteredEps;
}
return Reflect.get(target, prop);
}
})
];
};
10 changes: 4 additions & 6 deletions js/common/lib/inference-session-impl.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {resolveBackendAndExecutionProviders} from './backend-impl.js';
import {InferenceSessionHandler} from './backend.js';
import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
Expand Down Expand Up @@ -195,11 +195,9 @@ export class InferenceSession implements InferenceSessionInterface {
throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.');
}

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
// resolve backend, update session options with validated EPs, and create session handler
const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs);
TRACE_FUNC_END();
return new InferenceSession(handler);
}
Expand Down
11 changes: 5 additions & 6 deletions js/common/lib/training-session-impl.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {resolveBackend} from './backend-impl.js';
import {resolveBackendAndExecutionProviders} from './backend-impl.js';
import {SessionHandler, TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
Expand Down Expand Up @@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface {
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};

// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
// resolve backend, update session options with validated EPs, and create session handler
const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel,
optionsWithValidatedEPs);
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
Expand Down
17 changes: 10 additions & 7 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,17 @@ export interface OrtWasmModule extends EmscriptenModule {

// #region JSEP
/**
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
* This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
* This is the entry of JSEP initialization. This function is called once per backend when initializing ONNX Runtime.
* This function initializes Asyncify support.
* If name is 'webgpu', also initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
*/
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;
jsepInit?(name: 'webgpu', initParams: [
backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction,
run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, captureEnd: JSEP.CaptureEndFunction,
replay: JSEP.ReplayFunction
]): void;
jsepInit?(name: 'webnn', initParams?: never): void;

/**
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
Expand Down
32 changes: 24 additions & 8 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,27 +134,39 @@ class ComputeContextImpl implements ComputeContext {
/**
* Initialize JSEP for the specified EP ("webgpu" or "webnn").
*
* This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
* This function expects:
* This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
* each of the following EPs if they are specified:
* - "webgpu"
* - "webnn"
*
* For WebGPU, this function expects:
* - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
* - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
*
* For WebNN, this function expects:
* - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false; WebNN is gated by the same build flag as WebGPU).
* - WebNN is available in current environment. (navigator.ml is not undefined)
*
* If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
* 'webgpu' backend.
* 'webgpu'/'webnn' backend.
*
* @param name - the name of the EP, either "webgpu" or "webnn"
* @param module - the ORT WebAssembly module
* @param env - the ORT environment variable (ort.env)
* @param gpuAdapter - the pre-created GPU adapter. Required when name is "webgpu"; not used when name is "webnn".
*/
export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
export const init =
async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise<void> => {
const jsepInit = module.jsepInit;
if (!jsepInit) {
throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
}

const backend = new WebGpuBackend();
await backend.initialize(env, gpuAdapter);
if (name === 'webgpu') {
const backend = new WebGpuBackend();
await backend.initialize(env, gpuAdapter!);

jsepInit(
jsepInit('webgpu', [
// backend
backend,

Expand Down Expand Up @@ -208,5 +220,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
// jsepCaptureEnd
() => backend.captureEnd(),
// jsepReplay
() => backend.replay());
() => backend.replay()
]);
} else {
jsepInit('webnn');
}
};
2 changes: 1 addition & 1 deletion js/web/lib/wasm/proxy-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ export const createSession =
ensureWorker();
return new Promise<SerializableSessionMetadata>((resolve, reject) => {
enqueueCallbacks('create', [resolve, reject]);
const message: OrtWasmMessage = {type: 'create', in : {model, options}};
const message: OrtWasmMessage = {type: 'create', in : {model, options: {...options}}};
const transferable: Transferable[] = [];
if (model instanceof Uint8Array) {
transferable.push(model.buffer);
Expand Down
45 changes: 27 additions & 18 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,36 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
}
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
}
if (!BUILD_DEFS.DISABLE_WEBGPU) {
// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;

if (!env.wasm.simd) {
throw new Error(
'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
}
if (epName === 'webgpu') {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
}
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
}

// init JSEP if available
if (!env.wasm.simd) {
throw new Error(
'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
}

// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env, adapter);
await initJsep('webgpu', getInstance(), env, adapter);
}
if (epName === 'webnn') {
// perform WebNN availability check
if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) {
throw new Error('WebNN is not supported in current environment');
}

await initJsep('webnn', getInstance(), env);
}
}
};

Expand Down
Loading
Loading