// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { WebNNBackend } from '../backend-webnn';
import { tensorTypeToTypedArrayConstructor } from '../../wasm-common';
import { LOG_DEBUG } from '../log';

// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
// WebNN API specification.
// https://github.com/webmachinelearning/webnn/issues/677
///
// NOTE(review): the bare triple-slash line above looks like a truncated `/// <reference ... />` directive for the
// WebNN type declarations (MLContext, MLTensor, MLOperandDataType, ...) - confirm and restore the path.

/**
 * Map from MLOperandDataType to size in bits. Using bits instead of bytes to avoid possible precision loss on int4 and
 * uint4.
 */
const webnnDataTypeToSize = new Map<MLOperandDataType, number>([
  ['float32', 32],
  ['float16', 16],
  ['int32', 32],
  ['uint32', 32],
  ['int64', 64],
  ['uint64', 64],
  ['int8', 8],
  ['uint8', 8],
  ['int4', 4],
  ['uint4', 4],
]);

/**
 * Convert integer data to the bytes of an Int32Array.
 * Supports conversion from int64, uint64, uint32, int8 and uint8 to int32.
 *
 * @param data - Raw bytes of the values in their original data type.
 * @param dataType - The data type `data` is currently encoded in.
 * @returns `data` itself when `dataType` is 'int32'; otherwise a new Uint8Array viewing a freshly allocated
 * Int32Array that holds the same values.
 * @throws Error when the data type is unknown to the size table, when the byte length is not a multiple of the
 * element size, when a value does not fit into int32, or when the data type has no int32 conversion.
 */
export const convertDataToInt32 = (data: Uint8Array, dataType: MLOperandDataType): Uint8Array => {
  if (dataType === 'int32') {
    return data;
  }

  const dataTypeSize = webnnDataTypeToSize.get(dataType);
  if (!dataTypeSize) {
    throw new Error(`WebNN backend does not support data type: ${dataType}`);
  }
  const bytesPerElement = dataTypeSize / 8;
  // Make sure the data length is a multiple of the data type size.
  if (data.byteLength % bytesPerElement !== 0) {
    throw new Error(`Invalid Uint8Array length - must be a multiple of ${bytesPerElement}.`);
  }

  // Convert Uint8Array to original typed array.
  const numElements = data.byteLength / bytesPerElement;
  const originalArray = new (tensorTypeToTypedArrayConstructor(dataType))(data.buffer, data.byteOffset, numElements);

  switch (dataType) {
    case 'int64':
    case 'uint64': {
      // Convert original typed array to Int32Array.
      const int32Array = new Int32Array(numElements);
      for (let i = 0; i < numElements; i++) {
        const value = originalArray[i];

        // Check for overflow.
        if (value > 2147483647n || value < -2147483648n) {
          throw new Error(`Can not convert int64 data to int32 - value out of range.`);
        }

        int32Array[i] = Number(value);
      }

      return new Uint8Array(int32Array.buffer);
    }
    case 'int8':
    case 'uint8':
    case 'uint32': {
      // Check for overflow. int8 and uint8 always fit into int32, so only uint32 needs the check.
      if (dataType === 'uint32') {
        if (originalArray.some((value) => value > 2147483647)) {
          throw new Error(`Can not convert uint32 data to int32 - value out of range.`);
        }
      }
      // Convert original typed array to Int32Array.
      const int32Array = Int32Array.from(originalArray, Number);
      return new Uint8Array(int32Array.buffer);
    }
    default:
      throw new Error(`Unsupported data conversion from ${dataType} to 'int32'`);
  }
};

/**
 * Convert the bytes of an Int32Array back to the original integer data type.
 * Supports conversion from int32 to int64, uint64, uint32, int8 and uint8.
 *
 * @param data - Raw bytes of int32 values.
 * @param dataType - The data type to convert the values to.
 * @returns `data` itself when `dataType` is 'int32'; otherwise a new Uint8Array holding the values encoded as
 * `dataType`.
 * @throws Error when the byte length is not a multiple of 4, when a value is out of range for the target type, or
 * when the target type has no conversion from int32.
 */
export const convertInt32ToData = (data: Uint8Array, dataType: MLOperandDataType): Uint8Array => {
  if (dataType === 'int32') {
    return data;
  }

  // Make sure the data length is a multiple of 4 bytes (Int32Array).
  if (data.byteLength % 4 !== 0) {
    throw new Error('Invalid Uint8Array length - must be a multiple of 4 (int32).');
  }

  // Convert Uint8Array to Int32Array.
  const numElements = data.byteLength / 4;
  const int32Array = new Int32Array(data.buffer, data.byteOffset, numElements);

  switch (dataType) {
    case 'int64': {
      const bigInt64Array = BigInt64Array.from(int32Array, BigInt);
      return new Uint8Array(bigInt64Array.buffer);
    }
    case 'uint64': {
      if (int32Array.some((value) => value < 0)) {
        throw new Error('Can not convert int32 data to uint64 - negative value found.');
      }
      const bigUint64Array = BigUint64Array.from(int32Array, BigInt);
      return new Uint8Array(bigUint64Array.buffer);
    }
    case 'int8': {
      if (int32Array.some((value) => value < -128 || value > 127)) {
        throw new Error('Can not convert int32 data to int8 - value out of range.');
      }
      const int8Array = Int8Array.from(int32Array, Number);
      return new Uint8Array(int8Array.buffer);
    }
    case 'uint8': {
      if (int32Array.some((value) => value < 0 || value > 255)) {
        throw new Error('Can not convert int32 data to uint8 - value out of range.');
      }
      return Uint8Array.from(int32Array, Number);
    }
    case 'uint32': {
      if (int32Array.some((value) => value < 0)) {
        throw new Error('Can not convert int32 data to uint32 - negative value found.');
      }
      const uint32Array = Uint32Array.from(int32Array, Number);
      return new Uint8Array(uint32Array.buffer);
    }
    default:
      throw new Error(`Unsupported data conversion from 'int32' to ${dataType}`);
  }
};

export type TensorId = number;

/**
 * Manages TensorId to MLTensor mapping.
 */
export interface TensorManager {
  /**
   * Reserve a new TensorId.
   */
  reserveTensorId(): TensorId;
  /**
   * Release a TensorId.
   */
  releaseTensorId(tensorId: TensorId): void;
  /**
   * Ensure a MLTensor is created for the TensorId.
   */
  ensureTensor(
    sessionId: number,
    tensorId: TensorId,
    dataType: MLOperandDataType,
    shape: readonly number[],
    copyOld: boolean,
  ): Promise<MLTensor>;
  /**
   * Upload data to a MLTensor.
   */
  upload(tensorId: TensorId, data: Uint8Array): void;
  /**
   * Download data from a MLTensor.
   */
  download(tensorId: TensorId): Promise<ArrayBuffer>;
  download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise<undefined>;
  /**
   * Release all tensors for a given session.
   */
  releaseTensorsForSession(session: number): void;
  /**
   * Register an externally created MLTensor with a given session id and return a TensorId.
   */
  registerTensor(sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
}

// Monotonically increasing counter used to hand out unique tensor ids.
let tensorGuid = 1;
const createNewTensorId = (): TensorId => tensorGuid++;

/**
 * Map from data type to fallback data type.
 * When the context does not support the original data type, use fallback data type as workaround.
 * Note: Currently, we only support fallback to int32 for certain integer data types.
 */
const webnnDataTypeToFallback = new Map<MLOperandDataType, MLOperandDataType>([
  ['int8', 'int32'],
  ['uint8', 'int32'],
  ['uint32', 'int32'],
  ['int64', 'int32'],
]);

/**
 * Calculate the byte length of a tensor with the given data type and shape.
 *
 * The element count is the product of all dimensions; the bit size is rounded up to whole bytes.
 * An empty shape yields 0.
 * NOTE(review): a rank-0 (scalar) shape therefore reports 0 bytes rather than one element - confirm this is intended.
 *
 * @throws Error when the data type is not present in the size table.
 */
const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => {
  const dataTypeSize = webnnDataTypeToSize.get(dataType);
  if (!dataTypeSize) {
    throw new Error(`WebNN backend does not support data type: ${dataType}`);
  }
  return shape.length > 0 ? Math.ceil((shape.reduce((a, b) => a * b) * dataTypeSize) / 8) : 0;
};

/**
 * TensorWrapper wraps an MLTensor and provides a way to track the last session that used it.
 */
class TensorWrapper {
  // The id of the last session that used this tensor.
  public sessionId: number;
  // This flag is used to indicate whether the data has been converted to fallback data type.
  public isDataConverted = false;

  private mlContext: MLContext;
  private mlTensor: MLTensor;
  private dataType: MLOperandDataType;
  // Fallback data type to use when the context does not support the original data type.
  private fallbackDataType: MLOperandDataType | undefined;
  private tensorShape: readonly number[];

  constructor(descriptor: {
    sessionId: number;
    context: MLContext;
    tensor: MLTensor;
    dataType: MLOperandDataType;
    shape: readonly number[];
    fallbackDataType?: MLOperandDataType;
  }) {
    const { sessionId, context, tensor, dataType, shape, fallbackDataType } = descriptor;
    this.sessionId = sessionId;
    this.mlContext = context;
    this.mlTensor = tensor;
    this.dataType = dataType;
    this.tensorShape = shape;
    this.fallbackDataType = fallbackDataType;
  }

  /** The wrapped MLTensor. */
  public get tensor(): MLTensor {
    return this.mlTensor;
  }

  /** The original (requested) data type, not the fallback one. */
  public get type(): MLOperandDataType {
    return this.dataType;
  }

  /** The fallback data type the wrapper was created with, if any. */
  public get fallbackType(): MLOperandDataType | undefined {
    return this.fallbackDataType;
  }

  public get shape(): readonly number[] {
    return this.tensorShape;
  }

  /** Byte length computed from the original data type and the shape. */
  public get byteLength(): number {
    return calculateByteLength(this.dataType, this.tensorShape);
  }

  /** Destroy the wrapped MLTensor. */
  public destroy(): void {
    LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy');
    this.mlTensor.destroy();
  }

  /** Write `data` to the wrapped MLTensor through the owning context. No conversion is applied here. */
  public write(data: Uint8Array): void {
    this.mlContext.writeTensor(this.mlTensor, data);
  }

  /**
   * Read the tensor contents.
   *
   * When a fallback data type is set, the tensor is read in full and converted from int32 back to the original
   * data type before being returned or copied into `dstBuffer`. Otherwise the read is delegated to the context.
   *
   * @param dstBuffer - Optional destination; when given on the fallback path, the data is copied into it and
   * `undefined` is returned.
   */
  public async read(): Promise<ArrayBuffer>;
  public async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined>;
  public async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
    if (this.fallbackDataType) {
      // This tensor has been fallback to int32 as workaround, we need to read it as its original integer data type.
      const data = await this.mlContext.readTensor(this.mlTensor);
      const originalData = convertInt32ToData(new Uint8Array(data), this.dataType);

      if (dstBuffer) {
        const targetBuffer =
          dstBuffer instanceof ArrayBuffer
            ? new Uint8Array(dstBuffer)
            : new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength);
        targetBuffer.set(originalData);
        return undefined;
      } else {
        return originalData.buffer;
      }
    } else {
      return dstBuffer ? this.mlContext.readTensor(this.mlTensor, dstBuffer) : this.mlContext.readTensor(this.mlTensor);
    }
  }

  /**
   * Whether this wrapper can serve a request for the given context, data type and shape:
   * same context instance, same original data type, and element-wise identical shape.
   */
  public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
    return (
      this.mlContext === context &&
      this.dataType === dataType &&
      this.tensorShape.length === shape.length &&
      this.tensorShape.every((v, i) => v === shape[i])
    );
  }

  public setIsDataConverted(isConverted: boolean): void {
    this.isDataConverted = isConverted;
  }
}

/**
 * TensorIdTracker tracks the MLTensor and pending upload data.
 *
 * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until
 * we know the data type and shape. This is because WebNN only support creating MLTensors with dataTypes and shape.
 */
class TensorIdTracker {
  // Data uploaded (or copied from a previous tensor) that has not yet been written to an MLTensor.
  private activeUpload?: Uint8Array;

  constructor(
    private tensorManager: TensorManagerImpl,
    private wrapper?: TensorWrapper,
  ) {}

  public get tensorWrapper(): TensorWrapper | undefined {
    return this.wrapper;
  }

  /** Hand the current wrapper (if any) back to the manager and forget it. */
  public releaseTensor(): void {
    if (this.tensorWrapper) {
      this.tensorManager.releaseTensor(this.tensorWrapper);
      this.wrapper = undefined;
    }
  }

  /**
   * Make sure this tracker holds an MLTensor matching the session's context, `dataType` and `shape`.
   *
   * @param copyOld - When true and an existing, non-reusable tensor is replaced, its contents are read and
   * written to the new tensor; the byte lengths must match.
   * @returns The MLTensor now tracked.
   * @throws Error when neither `dataType` nor its fallback is listed in the session's supported input data types
   * (which includes the case where no support limits are available), or when `copyOld` is requested for tensors of
   * different byte length.
   */
  public async ensureTensor(
    sessionId: number,
    dataType: MLOperandDataType,
    shape: readonly number[],
    copyOld: boolean,
  ): Promise<MLTensor> {
    const context = this.tensorManager.getMLContext(sessionId);
    const opLimits = this.tensorManager.getMLOpSupportLimits(sessionId);
    let fallbackDataType: MLOperandDataType | undefined;
    // Check if the context supports the data type. If not, try to use the fallback data type.
    if (!opLimits?.input.dataTypes.includes(dataType)) {
      fallbackDataType = webnnDataTypeToFallback.get(dataType);
      // A fallback is only usable when one exists AND the context supports it.
      if (!fallbackDataType || !opLimits?.input.dataTypes.includes(fallbackDataType)) {
        throw new Error(`WebNN backend does not support data type: ${dataType}`);
      }
      LOG_DEBUG(
        'verbose',
        () => `[WebNN] TensorIdTracker.ensureTensor: fallback dataType from ${dataType} to ${fallbackDataType}`,
      );
    }

    if (this.wrapper) {
      if (this.wrapper.canReuseTensor(context, dataType, shape)) {
        return this.wrapper.tensor;
      } else {
        if (copyOld) {
          if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) {
            throw new Error('Unable to copy data to tensor with different size.');
          }
          this.activeUpload = new Uint8Array(await this.wrapper.read());
        }
        this.tensorManager.releaseTensor(this.wrapper);
      }
    }

    // eslint-disable-next-line no-bitwise
    const usage = typeof MLTensorUsage === 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
    this.wrapper = await this.tensorManager.getCachedTensor(
      sessionId,
      dataType,
      shape,
      usage,
      true,
      true,
      fallbackDataType,
    );

    if (copyOld && this.activeUpload) {
      // We don't need to convert the original integer data to int32,
      // because it has been converted when it was uploaded.
      this.wrapper.write(this.activeUpload);
      this.activeUpload = undefined;
    }

    return this.wrapper.tensor;
  }

  /**
   * Upload data for this tensor id.
   *
   * With a tensor present and a matching byte length, the data (converted to int32 when the tensor uses the int32
   * fallback) is written directly. On a size mismatch the tensor is released and the data is kept as a pending
   * upload; with no tensor present the data is kept as a pending upload as well.
   *
   * @throws Error when the wrapper's fallback type is anything other than 'int32'.
   */
  public upload(data: Uint8Array): void {
    let newData = data;
    if (this.wrapper) {
      if (this.wrapper.fallbackType) {
        if (this.wrapper.fallbackType === 'int32') {
          // Convert original integer data to int32.
          newData = convertDataToInt32(data, this.wrapper.type);
          this.wrapper.setIsDataConverted(true);
        } else {
          throw new Error(`Unsupported fallback data type: ${this.wrapper.fallbackType}`);
        }
      }

      // Check if the data size matches the tensor size.
      // The comparison uses the unconverted data, since byteLength is based on the original data type.
      if (data.byteLength === this.wrapper.byteLength) {
        // Write the newData to the tensor.
        this.wrapper.write(newData);
        return;
      } else {
        LOG_DEBUG('verbose', () => 'Data size does not match tensor size. Releasing tensor.');
        this.releaseTensor();
      }
    }

    if (this.activeUpload) {
      this.activeUpload.set(newData);
    } else {
      this.activeUpload = new Uint8Array(newData);
    }
  }

  /**
   * Download the data for this tensor id.
   *
   * Pending upload data takes precedence over the tensor contents. When `dstBuffer` is given the data is copied
   * into it and nothing is returned; otherwise the underlying buffer is returned.
   *
   * @throws Error when there is neither pending upload data nor a tensor.
   */
  public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
    if (this.activeUpload) {
      // If this.activeUpload has been converted to int32, we need to convert it back to original integer data type.
      const dstData = this.wrapper?.isDataConverted
        ? convertInt32ToData(this.activeUpload, this.wrapper?.type)
        : this.activeUpload;

      if (dstBuffer) {
        if (dstBuffer instanceof ArrayBuffer) {
          new Uint8Array(dstBuffer).set(dstData);
        } else {
          new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(dstData);
        }
        return;
      } else {
        return dstData.buffer;
      }
    }
    if (!this.wrapper) {
      throw new Error('Tensor has not been created.');
    }

    if (!dstBuffer) {
      return this.wrapper.read();
    }
    return this.wrapper.read(dstBuffer);
  }
}

class TensorManagerImpl implements TensorManager {
  private tensorTrackersById: Map<TensorId, TensorIdTracker> = new Map();
  // Released tensors kept for reuse by getCachedTensor.
  private freeTensors: TensorWrapper[] = [];
  // Wrappers created by registerTensor around externally supplied MLTensors.
  private externalTensors: Set<TensorWrapper> = new Set();

  constructor(private backend: WebNNBackend) {}

  /**
   * Look up the MLContext for a session.
   *
   * @throws Error when the backend has no context for the session.
   */
  public getMLContext(sessionId: number): MLContext {
    const context = this.backend.getMLContext(sessionId);
    if (!context) {
      throw new Error('MLContext not found for session.');
    }
    return context;
  }

  public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined {
    return this.backend.getMLOpSupportLimits(sessionId);
  }

  public reserveTensorId(): TensorId {
    const tensorId = createNewTensorId();
    this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
    return tensorId;
  }

  /** Forget a tensor id and release its tensor, if it has one. Unknown ids are ignored. */
  public releaseTensorId(tensorId: TensorId): void {
    const tensorTracker = this.tensorTrackersById.get(tensorId);
    if (!tensorTracker) {
      return;
    }
    this.tensorTrackersById.delete(tensorId);
    if (tensorTracker.tensorWrapper) {
      this.releaseTensor(tensorTracker.tensorWrapper);
    }
  }

  public async ensureTensor(
    sessionId: number,
    tensorId: TensorId,
    dataType: MLOperandDataType,
    shape: number[],
    copyOld: boolean,
  ): Promise<MLTensor> {
    LOG_DEBUG(
      'verbose',
      () =>
        `[WebNN] TensorManager.ensureTensor {tensorId: ${tensorId}, dataType: ${
          dataType
        }, shape: ${shape}, copyOld: ${copyOld}}`,
    );
    const tensor = this.tensorTrackersById.get(tensorId);
    if (!tensor) {
      throw new Error('Tensor not found.');
    }
    return tensor.ensureTensor(sessionId, dataType, shape, copyOld);
  }

  public upload(tensorId: TensorId, data: Uint8Array): void {
    const tensor = this.tensorTrackersById.get(tensorId);
    if (!tensor) {
      throw new Error('Tensor not found.');
    }
    tensor.upload(data);
  }

  public async download(tensorId: TensorId): Promise<ArrayBuffer>;
  public async download(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined>;
  async download(tensorId: TensorId, dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
    LOG_DEBUG(
      'verbose',
      () => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`,
    );
    const tensorTracker = this.tensorTrackersById.get(tensorId);
    if (!tensorTracker) {
      throw new Error('Tensor not found.');
    }
    return tensorTracker.download(dstBuffer);
  }

  /** Destroy and drop every pooled (free) tensor whose last session is `sessionId`. */
  public releaseTensorsForSession(sessionId: number): void {
    for (const tensor of this.freeTensors) {
      if (tensor.sessionId === sessionId) {
        tensor.destroy();
      }
    }
    this.freeTensors = this.freeTensors.filter((tensor) => tensor.sessionId !== sessionId);
  }

  /**
   * Wrap an externally created MLTensor, track it under a new TensorId and remember it as external.
   *
   * @throws Error when the session has no MLContext.
   */
  public registerTensor(
    sessionId: number,
    mlTensor: MLTensor,
    dataType: MLOperandDataType,
    shape: readonly number[],
  ): TensorId {
    const context = this.getMLContext(sessionId);
    const tensorId = createNewTensorId();
    const wrapper = new TensorWrapper({
      sessionId,
      context,
      tensor: mlTensor,
      dataType,
      shape,
    });
    this.tensorTrackersById.set(tensorId, new TensorIdTracker(this, wrapper));
    this.externalTensors.add(wrapper);
    return tensorId;
  }

  /**
   * Get or create an MLTensor with the given data type and shape.
   *
   * A pooled tensor that matches context, data type and shape is removed from the pool and re-assigned to
   * `sessionId`; otherwise a new tensor is created, using `fallbackDataType` in place of `dataType` when provided.
   */
  public async getCachedTensor(
    sessionId: number,
    dataType: MLOperandDataType,
    shape: readonly number[],
    usage: MLTensorUsageFlags | undefined,
    writable: boolean,
    readable: boolean,
    fallbackDataType?: MLOperandDataType,
  ): Promise<TensorWrapper> {
    const context = this.getMLContext(sessionId);
    for (const [index, tensor] of this.freeTensors.entries()) {
      if (tensor.canReuseTensor(context, dataType, shape)) {
        LOG_DEBUG(
          'verbose',
          () =>
            `[WebNN] Reusing tensor {dataType: ${dataType}, ${
              fallbackDataType ? `fallbackDataType: ${fallbackDataType},` : ''
            } shape: ${shape}`,
        );
        const wrapper = this.freeTensors.splice(index, 1)[0];
        wrapper.sessionId = sessionId;
        return wrapper;
      }
    }
    LOG_DEBUG(
      'verbose',
      () =>
        `[WebNN] MLContext.createTensor {dataType: ${dataType}, ${
          fallbackDataType ? `fallbackDataType: ${fallbackDataType},` : ''
        } shape: ${shape}}`,
    );
    const tensor = await context.createTensor({
      dataType: fallbackDataType ?? dataType, // If fallback data type is provided, use it.
      shape,
      dimensions: shape,
      usage,
      writable,
      readable,
    });
    return new TensorWrapper({ sessionId, context, tensor, dataType, shape, fallbackDataType });
  }

  /**
   * Release tensor for reuse unless external.
   *
   * External tensors (registered through registerTensor) are only un-tracked; they are never added to the
   * free pool, so they cannot be handed out by getCachedTensor or destroyed by releaseTensorsForSession.
   */
  public releaseTensor(tensorWrapper: TensorWrapper) {
    if (this.externalTensors.has(tensorWrapper)) {
      this.externalTensors.delete(tensorWrapper);
    } else {
      this.freeTensors.push(tensorWrapper);
    }
  }
}

export const createTensorManager = (...args: ConstructorParameters<typeof TensorManagerImpl>): TensorManager =>
  new TensorManagerImpl(...args);