Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] rewrite backend resolve to allow multiple EPs #19735

Merged
merged 6 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
webnn support
  • Loading branch information
fs-eire committed Mar 5, 2024
commit b688dee3ac3cbba15b6cb69308c982e63ddc7721
17 changes: 10 additions & 7 deletions js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,17 @@ export interface OrtWasmModule extends EmscriptenModule {

// #region JSEP
/**
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
* This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
* This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per backend.
* This function initializes Asyncify support.
* If name is 'webgpu', also initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
*/
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;
jsepInit?(name: 'webgpu', initParams: [
backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction,
run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, captureEnd: JSEP.CaptureEndFunction,
replay: JSEP.ReplayFunction
]): void;
jsepInit?(name: 'webnn', initParams?: never): void;

/**
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
Expand Down
32 changes: 24 additions & 8 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,27 +134,39 @@ class ComputeContextImpl implements ComputeContext {
/**
* Initialize JSEP with WebGPU backend.
*
* This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
* This function expects:
* This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
* each of the following EPs if they are specified:
* - "webgpu"
* - "webnn"
*
* For WebGPU, this function expects:
* - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
* - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
*
* For WebNN, this function expects:
* - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
* - WebNN is available in current environment. (navigator.ml is not undefined)
*
* If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
* 'webgpu' backend.
* 'webgpu'/'webnn' backend.
*
* @param name - the name of the EP, either "webgpu" or "webnn"
* @param module - the ORT WebAssembly module
* @param env - the ORT environment variable (ort.env)
* @param gpuAdapter - the pre-created GPU adapter
*/
export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
export const init =
async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise<void> => {
const jsepInit = module.jsepInit;
if (!jsepInit) {
throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
}

const backend = new WebGpuBackend();
await backend.initialize(env, gpuAdapter);
if (name === 'webgpu') {
const backend = new WebGpuBackend();
await backend.initialize(env, gpuAdapter!);

jsepInit(
jsepInit('webgpu', [
// backend
backend,

Expand Down Expand Up @@ -208,5 +220,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
// jsepCaptureEnd
() => backend.captureEnd(),
// jsepReplay
() => backend.replay());
() => backend.replay()
]);
} else {
jsepInit('webnn');
}
};
45 changes: 27 additions & 18 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,36 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
}
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
}
if (!BUILD_DEFS.DISABLE_WEBGPU) {
// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;

if (!env.wasm.simd) {
throw new Error(
'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
}
if (epName === 'webgpu') {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
}
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
throw new Error(
'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
}

// init JSEP if available
if (!env.wasm.simd) {
throw new Error(
'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
}

// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env, adapter);
await initJsep('webgpu', getInstance(), env, adapter);
}
if (epName === 'webnn') {
// perform WebNN availability check
if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) {
throw new Error('WebNN is not supported in current environment');
}

await initJsep('webnn', getInstance(), env);
}
}
};

Expand Down
75 changes: 44 additions & 31 deletions onnxruntime/wasm/js_internal_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
*/
Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => {
const files = Module.MountedFiles || (Module.MountedFiles = new Map());
files.set(externalDataFilePath, externalDataFileData);
files.set(externalDataFilePath, externalDataFileData);
};

/**
Expand All @@ -22,21 +22,9 @@ Module['unmountExternalData'] = () => {
};

/**
* init JSEP
* initialize JSEP for asyncify support.
*/
Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => {
Module.jsepBackend = backend;
Module.jsepAlloc = alloc;
Module.jsepFree = free;
Module.jsepCopy = copy;
Module.jsepCopyAsync = copyAsync;
Module.jsepCreateKernel = createKernel;
Module.jsepReleaseKernel = releaseKernel;
Module.jsepRunKernel = runKernel;
Module.jsepCaptureBegin = captureBegin;
Module.jsepCaptureEnd = captureEnd;
Module.jsepReplay = replay;

let jsepInitAsync = () => {
// This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1)
// It removes some overhead in cwarp() and ccall() that we don't need.
//
Expand Down Expand Up @@ -180,20 +168,45 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
() => Module['_OrtBindInput'],
v => Module['_OrtBindInput'] = v);

// expose webgpu backend functions
Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => {
return backend['registerBuffer'](sessionId, index, buffer, size);
};
Module['jsepGetBuffer'] = (dataId) => {
return backend['getBuffer'](dataId);
};
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepOnReleaseSession'] = sessionId => {
backend['onReleaseSession'](sessionId);
};
Module['jsepOnRunStart'] = sessionId => {
return backend['onRunStart'](sessionId);
};
// remove this function to make sure it is called only once.
jsepInitAsync = undefined;
};


/**
* initialize JSEP for WebGPU.
*/
Module['jsepInit'] = (name, params) => {
jsepInitAsync?.();

if (name === 'webgpu') {
[Module.jsepBackend, Module.jsepAlloc,
Module.jsepFree,
Module.jsepCopy,
Module.jsepCopyAsync,
Module.jsepCreateKernel,
Module.jsepReleaseKernel,
Module.jsepRunKernel,
Module.jsepCaptureBegin,
Module.jsepCaptureEnd,
Module.jsepReplay] = params;

// expose webgpu backend functions
const backend = Module.jsepBackend;
Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => {
return backend['registerBuffer'](sessionId, index, buffer, size);
};
Module['jsepGetBuffer'] = (dataId) => {
return backend['getBuffer'](dataId);
};
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepOnReleaseSession'] = sessionId => {
backend['onReleaseSession'](sessionId);
};
Module['jsepOnRunStart'] = sessionId => {
return backend['onRunStart'](sessionId);
};
}
};
Loading