// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env, InferenceSession} from 'onnxruntime-common';

import {OrtWasmMessage, SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
import * as core from './wasm-core-impl';
import {initializeWebAssembly} from './wasm-factory';
import {importProxyWorker} from './wasm-utils-import';

/**
 * Returns true when API calls should be forwarded to the proxy worker: `env.wasm.proxy` is truthy and a
 * `document` global is defined.
 */
const isProxy = (): boolean => !!env.wasm.proxy && typeof document !== 'undefined';

// Module-level state of the proxy worker and of the one-shot initialization.
let proxyWorker: Worker|undefined;
let initializing = false;
let initialized = false;
let aborted = false;
// Object URL returned by `importProxyWorker()`; revoked once initialization settles.
let temporaryObjectUrl: string|undefined;

/** The `[resolve, reject]` pair of a pending promise. */
type PromiseCallbacks<T = void> = [resolve: (result: T) => void, reject: (reason: unknown) => void];
let initWasmCallbacks: PromiseCallbacks;
/**
 * Pending callbacks per request type. Responses are matched to requests in FIFO order, so this assumes the
 * worker answers requests of a given type in the order they were posted.
 */
const queuedCallbacks: Map<OrtWasmMessage['type'], Array<PromiseCallbacks<unknown>>> = new Map();

/** Appends `callbacks` to the FIFO queue of pending requests of the given `type`. */
const enqueueCallbacks = <T>(type: OrtWasmMessage['type'], callbacks: PromiseCallbacks<T>): void => {
  // Type-erase the resolver for storage: it is only invoked with the worker's `out` payload for this `type`.
  const entry = callbacks as PromiseCallbacks<unknown>;
  const queue = queuedCallbacks.get(type);
  if (queue) {
    queue.push(entry);
  } else {
    queuedCallbacks.set(type, [entry]);
  }
};

/**
 * Returns the proxy worker, or throws `'worker not ready'` unless initialization has completed successfully
 * and a worker exists.
 */
const ensureWorker = (): Worker => {
  if (initializing || !initialized || aborted || !proxyWorker) {
    throw new Error('worker not ready');
  }
  return proxyWorker;
};

/** Revokes the object URL the proxy worker was imported from, if one is still held. */
const revokeTemporaryObjectUrl = (): void => {
  if (temporaryObjectUrl) {
    URL.revokeObjectURL(temporaryObjectUrl);
    temporaryObjectUrl = undefined;
  }
};

/**
 * Handles every message posted back by the proxy worker. An 'init-wasm' reply settles the initialization
 * promise and updates the init state; any other known type settles the oldest pending request of that type.
 */
const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
  switch (ev.data.type) {
    case 'init-wasm':
      initializing = false;
      if (ev.data.err) {
        aborted = true;
        initWasmCallbacks[1](ev.data.err);
      } else {
        initialized = true;
        initWasmCallbacks[0]();
      }
      revokeTemporaryObjectUrl();
      break;
    case 'init-ep':
    case 'copy-from':
    case 'create':
    case 'release':
    case 'run':
    case 'end-profiling': {
      const callbacks = queuedCallbacks.get(ev.data.type)!;
      if (ev.data.err) {
        callbacks.shift()![1](ev.data.err);
      } else {
        callbacks.shift()![0](ev.data.out!);
      }
      break;
    }
    default:
  }
};

/**
 * Initializes the WebAssembly module and the ORT runtime, either inside the proxy worker (proxy mode) or in
 * the current context. Succeeds at most once; calls after a success return immediately.
 *
 * On any failure the module is marked aborted, so later calls report the earlier failure instead of
 * claiming that a call is still in progress.
 *
 * @throws Error if a previous call is still in progress, or if a previous call failed.
 */
export const initializeWebAssemblyAndOrtRuntime = async(): Promise<void> => {
  if (initialized) {
    return;
  }
  if (initializing) {
    throw new Error('multiple calls to \'initWasm()\' detected.');
  }
  if (aborted) {
    throw new Error('previous call to \'initWasm()\' failed.');
  }

  initializing = true;

  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    return new Promise<void>((resolve, reject) => {
      proxyWorker?.terminate();

      // Settles a failure that happens before the worker replies to 'init-wasm'. It leaves the module in the
      // same state as the other failure paths: not initializing, aborted, object URL revoked. If the
      // 'init-wasm' reply already settled initialization, this is a no-op, e.g. for a later worker error event.
      const fail = (reason: unknown): void => {
        if (!initializing) {
          return;
        }
        initializing = false;
        aborted = true;
        revokeTemporaryObjectUrl();
        reject(reason);
      };

      void importProxyWorker().then(([objectUrl, worker]) => {
        // Record the URL before anything below can throw, so every failure path can revoke it.
        temporaryObjectUrl = objectUrl;
        try {
          proxyWorker = worker;
          proxyWorker.onerror = (ev: ErrorEvent) => fail(ev);
          proxyWorker.onmessage = onProxyWorkerMessage;
          initWasmCallbacks = [resolve, reject];
          const message: OrtWasmMessage = {type: 'init-wasm', in: env};
          proxyWorker.postMessage(message);
        } catch (e) {
          fail(e);
        }
      }, fail);
    });
  } else {
    try {
      await initializeWebAssembly(env.wasm);
      await core.initRuntime(env);
      initialized = true;
    } catch (e) {
      aborted = true;
      throw e;
    } finally {
      initializing = false;
    }
  }
};

/**
 * Initializes the execution provider `epName`, in the proxy worker or in the current context.
 *
 * @throws Error `'worker not ready'` in proxy mode if initialization has not completed successfully.
 */
export const initializeOrtEp = async(epName: string): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    const worker = ensureWorker();
    return new Promise<void>((resolve, reject) => {
      enqueueCallbacks('init-ep', [resolve, reject]);
      const message: OrtWasmMessage = {type: 'init-ep', in: {epName, env}};
      worker.postMessage(message);
    });
  } else {
    await core.initEp(env, epName);
  }
};

/**
 * Copies `buffer` into a buffer managed by the WASM core and returns its handle.
 *
 * In proxy mode `buffer.buffer` is listed as a transferable, so the caller's underlying `ArrayBuffer` is
 * detached after this call.
 */
export const copyFromExternalBuffer = async(buffer: Uint8Array): Promise<SerializableInternalBuffer> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    const worker = ensureWorker();
    return new Promise<SerializableInternalBuffer>((resolve, reject) => {
      enqueueCallbacks('copy-from', [resolve, reject]);
      const message: OrtWasmMessage = {type: 'copy-from', in: {buffer}};
      worker.postMessage(message, [buffer.buffer]);
    });
  } else {
    return core.copyFromExternalBuffer(buffer);
  }
};

/**
 * Creates an inference session from `model` and returns its metadata.
 *
 * In proxy mode, a `Uint8Array` model's underlying `ArrayBuffer` is transferred (and therefore detached), and
 * a shallow copy of `options` is sent to the worker.
 *
 * @throws Error in proxy mode if `options.preferredOutputLocation` is set, or if the worker is not ready.
 */
export const createSession = async(
    model: SerializableInternalBuffer|Uint8Array,
    options?: InferenceSession.SessionOptions): Promise<SerializableSessionMetadata> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    // check unsupported options
    if (options?.preferredOutputLocation) {
      throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
    }
    const worker = ensureWorker();
    return new Promise<SerializableSessionMetadata>((resolve, reject) => {
      enqueueCallbacks('create', [resolve, reject]);
      const message: OrtWasmMessage = {type: 'create', in: {model, options: {...options}}};
      const transferable: Transferable[] = [];
      if (model instanceof Uint8Array) {
        transferable.push(model.buffer);
      }
      worker.postMessage(message, transferable);
    });
  } else {
    return core.createSession(model, options);
  }
};

/** Releases the session `sessionId`, in the proxy worker or in the current context. */
export const releaseSession = async(sessionId: number): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    const worker = ensureWorker();
    return new Promise<void>((resolve, reject) => {
      enqueueCallbacks('release', [resolve, reject]);
      const message: OrtWasmMessage = {type: 'release', in: sessionId};
      worker.postMessage(message);
    });
  } else {
    core.releaseSession(sessionId);
  }
};

/**
 * Runs the session `sessionId` and returns the output tensors' metadata.
 *
 * Proxy mode only supports CPU inputs and no pre-allocated outputs. The input buffers returned by
 * `core.extractTransferableBuffers` are transferred to the worker.
 *
 * @throws Error in proxy mode if any input's location (element 3) is not `'cpu'`, if any entry of `outputs`
 *   is truthy, or if the worker is not ready.
 */
export const run = async(
    sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[],
    outputs: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    // check inputs location
    if (inputs.some(t => t[3] !== 'cpu')) {
      throw new Error('input tensor on GPU is not supported for proxy.');
    }
    // check outputs location
    if (outputs.some(t => t)) {
      throw new Error('pre-allocated output tensor is not supported for proxy.');
    }
    const worker = ensureWorker();
    return new Promise<SerializableTensorMetadata[]>((resolve, reject) => {
      enqueueCallbacks('run', [resolve, reject]);
      // Every input was checked above to be on CPU.
      const serializableInputs = inputs as SerializableTensorMetadata[];
      const message: OrtWasmMessage =
          {type: 'run', in: {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}};
      worker.postMessage(message, core.extractTransferableBuffers(serializableInputs));
    });
  } else {
    return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options);
  }
};

/** Ends profiling for the session `sessionId`, in the proxy worker or in the current context. */
export const endProfiling = async(sessionId: number): Promise<void> => {
  if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
    const worker = ensureWorker();
    return new Promise<void>((resolve, reject) => {
      enqueueCallbacks('end-profiling', [resolve, reject]);
      const message: OrtWasmMessage = {type: 'end-profiling', in: sessionId};
      worker.postMessage(message);
    });
  } else {
    core.endProfiling(sessionId);
  }
};