// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, Tensor} from 'onnxruntime-common';

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {
  dataLocationStringToEnum,
  tensorDataTypeEnumToString,
  tensorDataTypeStringToEnum,
  tensorTypeToTypedArrayConstructor
} from './wasm-common';
import {prepareInputOutputTensor} from './wasm-core-impl';
import {getInstance} from './wasm-factory';
import {checkLastError} from './wasm-utils';

const NO_TRAIN_FUNCS_MSG =
    'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' +
    'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
    'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';

/**
 * Calls `checkLastError(message)` when `errCode` matches the selected error pattern; does nothing otherwise.
 *
 * @param errCode number to evaluate for whether it represents an error
 * @param message message to pass into checkLastError
 * @param checkNeqZero when true, treats not equal to zero as an error.
 *                     When false, treats equal to zero as an error (used for handles/pointers).
 */
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true): void => {
  const isError = checkNeqZero ? errCode !== 0 : errCode === 0;
  if (isError) {
    checkLastError(message);
  }
};

/**
 * Loads a checkpoint from a buffer that already lives on the WASM heap and returns the resulting handle.
 *
 * The checkpoint buffer is always freed (via `_OrtFree`) before this function returns or throws.
 *
 * @param checkpointData [offset, length] pair describing the checkpoint bytes on the WASM heap
 * @returns the non-zero checkpoint handle
 * @throws Error if the training exports are missing, or whatever `checkLastError` raises when the handle is 0
 */
export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
  const wasm = getInstance();

  const [checkpointDataOffset, checkpointDataLength] = checkpointData;
  let checkpointHandle = 0;

  try {
    if (!wasm._OrtTrainingLoadCheckpoint) {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
    checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);

    // a zero handle signals failure, hence checkNeqZero = false
    ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
    return checkpointHandle;
  } catch (e) {
    if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
      wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
    }
    throw e;
  } finally {
    // free buffer from wasm heap
    wasm._OrtFree(checkpointDataOffset);
  }
};

/**
 * Queries the number of inputs and outputs of the training or eval model.
 *
 * @param trainingSessionId handle of the training session
 * @param isEvalModel forwarded to `_OrtTrainingGetModelInputOutputCount`
 * @returns `[inputCount, outputCount]`
 */
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
  const wasm = getInstance();
  const stack = wasm.stackSave();
  try {
    // two consecutive 32-bit slots: [0] input count, [1] output count
    const dataOffset = wasm.stackAlloc(8);
    if (wasm._OrtTrainingGetModelInputOutputCount) {
      const errorCode =
          wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel);
      ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.');
      return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};

/**
 * Reads `count` input or output names, one `_OrtTrainingGetModelInputOutputName` call per index.
 *
 * Each returned pointer is decoded as a UTF-8 string and then freed.
 *
 * @param trainingSessionId handle of the training session
 * @param count number of names to read
 * @param isInput forwarded to the WASM call; also reported in the error message
 * @param isEvalModel forwarded to the WASM call
 * @returns the names in index order
 */
const getModelInputOutputNamesLoop =
    (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => {
      const names: string[] = [];
      const wasm = getInstance();

      for (let i = 0; i < count; i++) {
        if (wasm._OrtTrainingGetModelInputOutputName) {
          const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
          // a zero pointer signals failure, hence checkNeqZero = false
          ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);

          names.push(wasm.UTF8ToString(name));
          wasm._free(name);
        } else {
          throw new Error(NO_TRAIN_FUNCS_MSG);
        }
      }
      return names;
    };

/**
 * Retrieves the input and output names of the training or eval model.
 *
 * @param trainingSessionId handle of the training session
 * @param isEvalModel forwarded to the underlying count/name queries
 * @returns `[inputNames, outputNames]`
 */
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
  const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);

  const inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
  const outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);

  return [inputNames, outputNames];
};

/**
 * Creates a training session from a checkpoint handle and the model buffers already placed on the WASM heap.
 *
 * The three model buffers, the session options handle and the session option allocations are always released
 * before this function returns or throws.
 *
 * @param checkpointHandle handle previously returned by createCheckpointHandle
 * @param trainModelData [offset, length] of the training model on the WASM heap
 * @param evalModelData [offset, length] of the eval model on the WASM heap
 * @param optimizerModelData [offset, length] of the optimizer model on the WASM heap
 * @param options session options forwarded to setSessionOptions
 * @returns the non-zero training session handle
 */
export const createTrainingSessionHandle =
    (checkpointHandle: number, trainModelData: SerializableInternalBuffer, evalModelData: SerializableInternalBuffer,
     optimizerModelData: SerializableInternalBuffer, options: InferenceSession.SessionOptions): number => {
      const wasm = getInstance();

      let trainingSessionHandle = 0;
      let sessionOptionsHandle = 0;
      let allocs: number[] = [];

      try {
        [sessionOptionsHandle, allocs] = setSessionOptions(options);
        if (wasm._OrtTrainingCreateSession) {
          trainingSessionHandle = wasm._OrtTrainingCreateSession(
              sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
              evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
        } else {
          throw new Error(NO_TRAIN_FUNCS_MSG);
        }

        // a zero handle signals failure, hence checkNeqZero = false
        ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
        return trainingSessionHandle;
      } catch (e) {
        if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
          wasm._OrtTrainingReleaseSession(trainingSessionHandle);
        }
        throw e;
      } finally {
        wasm._free(trainModelData[0]);
        wasm._free(evalModelData[0]);
        wasm._free(optimizerModelData[0]);

        if (sessionOptionsHandle !== 0) {
          wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
        }
        allocs.forEach(alloc => wasm._free(alloc));
      }
    };

/**
 * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the
 * WASM tensors.
 *
 * The returned handle list is allocated with `stackAlloc`; the caller owns the surrounding
 * `stackSave`/`stackRestore` pair.
 *
 * @param trainingSessionId
 * @param indices for each tensor, the index of the input or output name that the tensor corresponds with
 * @param tensors list of TensorMetaData
 * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
 *                      handles of the allocated tensors on the heap
 * @param inputOutputAllocs modified in-place by this method
 * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
 * @returns offset of the stack-allocated array holding one 32-bit handle per tensor
 */
const createAndAllocateTensors =
    (trainingSessionId: number, indices: number[], tensors: Array<TensorMetadata|null>, tensorHandles: number[],
     inputOutputAllocs: number[], indexAdd: number): number => {
      const count = indices.length;

      // creates the tensors
      for (let i = 0; i < count; i++) {
        prepareInputOutputTensor(
            tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
      }

      // moves to heap
      const wasm = getInstance();
      const valuesOffset = wasm.stackAlloc(count * 4);
      let valuesIndex = valuesOffset / 4;
      for (let i = 0; i < count; i++) {
        wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
      }

      return valuesOffset;
    };

/**
 * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information
 * associated with the tensor handle.
 *
 * An output whose handle equals the corresponding entry of `outputTensorHandles` is treated as pre-allocated: its
 * original metadata is returned as-is and the handle is NOT released here.
 *
 * @param outputValuesOffset offset of the array of output tensor handles
 * @param outputCount number of handles to read
 * @param outputTensorHandles handles created for the outputs before the run
 * @param outputTensors metadata passed in for the outputs before the run
 * @returns list of TensorMetadata retrieved from the output handles.
 */
const moveOutputToTensorMetadataArr =
    (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[],
     outputTensors: Array<TensorMetadata|null>): TensorMetadata[] => {
      const wasm = getInstance();
      const output: TensorMetadata[] = [];

      for (let i = 0; i < outputCount; i++) {
        const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
        if (tensor === outputTensorHandles[i]) {
          // output tensor is pre-allocated. no need to copy data.
          output.push(outputTensors[i]!);
          continue;
        }

        const beforeGetTensorDataStack = wasm.stackSave();
        // stack allocate 4 pointer value
        const tensorDataOffset = wasm.stackAlloc(4 * 4);

        let type: Tensor.Type|undefined, dataOffset = 0;
        try {
          const errorCode = wasm._OrtGetTensorData(
              tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
          ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);

          // slot order: data type, data offset, dims offset, dims length
          let tensorDataIndex = tensorDataOffset / 4;
          const dataType = wasm.HEAPU32[tensorDataIndex++];
          dataOffset = wasm.HEAPU32[tensorDataIndex++];
          const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
          const dimsLength = wasm.HEAPU32[tensorDataIndex++];
          const dims: number[] = [];
          for (let d = 0; d < dimsLength; d++) {
            dims.push(wasm.HEAPU32[dimsOffset / 4 + d]);
          }
          wasm._OrtFree(dimsOffset);

          const size = dims.reduce((a, b) => a * b, 1);
          type = tensorDataTypeEnumToString(dataType);

          if (type === 'string') {
            const stringData: string[] = [];
            let dataIndex = dataOffset / 4;
            for (let s = 0; s < size; s++) {
              const offset = wasm.HEAPU32[dataIndex++];
              // every string but the last is bounded by the start offset of the next one
              const maxBytesToRead = s === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
              stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
            }
            output.push([type, dims, stringData, 'cpu']);
          } else {
            const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
            const data = new typedArrayConstructor(size);
            new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
                .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength));
            output.push([type, dims, data, 'cpu']);
          }
        } finally {
          wasm.stackRestore(beforeGetTensorDataStack);
          if (type === 'string' && dataOffset) {
            wasm._free(dataOffset);
          }
          wasm._OrtReleaseTensor(tensor);
        }
      }

      return output;
    };

/**
 * Calls `_OrtTrainingLazyResetGrad` for the given training session.
 *
 * @param trainingSessionId handle of the training session
 */
export const lazyResetGrad = async(trainingSessionId: number): Promise<void> => {
  const wasm = getInstance();

  if (wasm._OrtTrainingLazyResetGrad) {
    const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
    ifErrCodeCheckLastError(errorCode, 'Can\'t call lazyResetGrad.');
  } else {
    throw new Error(NO_TRAIN_FUNCS_MSG);
  }
};

/**
 * Runs one train step via `_OrtTrainingRunTrainStep` and returns the outputs as TensorMetadata.
 *
 * All tensor handles, allocations, run options and the stack frame created here are released in `finally`.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of the input name it corresponds with
 * @param inputTensors input tensor metadata
 * @param outputIndices for each output tensor, the index of the output name it corresponds with
 * @param outputTensors output tensor metadata; entries whose handle is returned unchanged are reused as-is
 * @param options run options forwarded to setRunOptions
 * @returns metadata for every output
 */
export const runTrainStep = async(
    trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
    outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
        trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
        trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);

    if (wasm._OrtTrainingRunTrainStep) {
      const errorCode = wasm._OrtTrainingRunTrainStep(
          trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);
      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach(p => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach(p => wasm._free(p));
  }
};

/**
 * Runs one optimizer step via `_OrtTrainingOptimizerStep`.
 *
 * @param trainingSessionId handle of the training session
 * @param options run options forwarded to setRunOptions
 */
export const runOptimizerStep =
    async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise<void> => {
  const wasm = getInstance();

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    if (wasm._OrtTrainingOptimizerStep) {
      const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
      ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach(p => wasm._free(p));
  }
};

/**
 * Runs one eval step via `_OrtTrainingEvalStep` and returns the outputs as TensorMetadata.
 *
 * All tensor handles, allocations, run options and the stack frame created here are released in `finally`.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of the input name it corresponds with
 * @param inputTensors input tensor metadata
 * @param outputIndices for each output tensor, the index of the output name it corresponds with
 * @param outputTensors output tensor metadata; entries whose handle is returned unchanged are reused as-is
 * @param options run options forwarded to setRunOptions
 * @returns metadata for every output
 */
export const runEvalStep = async(
    trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
    outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
        trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
        trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);

    if (wasm._OrtTrainingEvalStep) {
      const errorCode = wasm._OrtTrainingEvalStep(
          trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);

      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach(p => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach(p => wasm._free(p));
  }
};

/**
 * Queries the parameter count of the training session via `_OrtTrainingGetParametersSize`.
 *
 * @param trainingSessionId handle of the training session
 * @param trainableOnly forwarded to the WASM call
 * @returns the value written by the WASM call, read as a signed 32-bit integer
 */
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  try {
    const sizeOffset = wasm.stackAlloc(4);
    if (wasm._OrtTrainingGetParametersSize) {
      const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
      ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size');

      return wasm.HEAP32[sizeOffset / 4];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};

/**
 * Copies the session parameters into a freshly allocated WASM buffer and returns them as a 1-D float32
 * TensorMetadata living on the JavaScript side.
 *
 * @param trainingSessionId handle of the training session
 * @param trainableOnly forwarded to getParametersSize and `_OrtTrainingCopyParametersToBuffer`
 * @returns `['float32', [parametersSize], data, 'cpu']`
 */
export const getContiguousParameters =
    async(trainingSessionId: number, trainableOnly: boolean): Promise<TensorMetadata> => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  const tensorTypeAsString = 'float32';
  const locationAsString = 'cpu';

  let tensor = 0;
  let paramsOffset = 0;

  try {
    const parametersSize = getParametersSize(trainingSessionId, trainableOnly);

    // allocates a buffer of the correct size on the WASM heap (4 bytes per float32 element)
    const paramsByteLength = 4 * parametersSize;
    paramsOffset = wasm._malloc(paramsByteLength);

    // handles the dimensions-related createTensor parameters.
    // dimsOffset is STACK memory: it is reclaimed by stackRestore below and must never be passed to _free.
    const dims = [parametersSize];
    const dimsOffset = wasm.stackAlloc(4);
    wasm.HEAP32[dimsOffset / 4] = parametersSize;

    // wraps allocated array in a tensor
    tensor = wasm._OrtCreateTensor(
        tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length,
        dataLocationStringToEnum(locationAsString));
    ifErrCodeCheckLastError(
        tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false);

    if (wasm._OrtTrainingCopyParametersToBuffer) {
      const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
      ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
    const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
    const data = new typedArrayConstructor(parametersSize);
    new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
        .set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength));

    const output: TensorMetadata = [tensorTypeAsString, dims, data, locationAsString];
    return output;
  } finally {
    if (tensor !== 0) {
      wasm._OrtReleaseTensor(tensor);
    }
    if (paramsOffset !== 0) {
      wasm._free(paramsOffset);
    }
    wasm.stackRestore(stack);
  }
};

/**
 * Copies a JavaScript byte buffer of float32 parameters into the training session via
 * `_OrtTrainingCopyParametersFromBuffer`.
 *
 * NOTE(review): the element count is computed as `buffer.length / 4`; callers are presumably expected to pass a
 * buffer whose length is a multiple of 4 -- confirm against callers.
 *
 * @param trainingSessionId handle of the training session
 * @param buffer raw bytes of the float32 parameters
 * @param trainableOnly forwarded to the WASM call
 */
export const loadParametersBuffer =
    async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise<void> => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  const tensorTypeAsString = 'float32';
  const locationAsString = 'cpu';

  const bufferByteLength = buffer.length;
  const bufferCount = bufferByteLength / 4;
  const dimsLength = 1;

  let tensor = 0;
  let bufferOffset = 0;

  try {
    // allocates & copies JavaScript buffer to WASM heap
    bufferOffset = wasm._malloc(bufferByteLength);
    wasm.HEAPU8.set(buffer, bufferOffset);

    // allocates and handles moving dimensions information to WASM memory.
    // dimsOffset is STACK memory: it is reclaimed by stackRestore below and must never be passed to _free.
    const dimsOffset = wasm.stackAlloc(4);
    wasm.HEAP32[dimsOffset / 4] = bufferCount;

    tensor = wasm._OrtCreateTensor(
        tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength,
        dataLocationStringToEnum(locationAsString));
    ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);

    if (wasm._OrtTrainingCopyParametersFromBuffer) {
      const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
      ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    if (tensor !== 0) {
      wasm._OrtReleaseTensor(tensor);
    }
    wasm.stackRestore(stack);
    if (bufferOffset !== 0) {
      wasm._free(bufferOffset);
    }
  }
};

/**
 * Releases the training session and then the checkpoint, for whichever release exports are present.
 *
 * @param checkpointId handle of the checkpoint to release
 * @param sessionId handle of the training session to release
 */
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
  const wasm = getInstance();

  if (wasm._OrtTrainingReleaseSession) {
    wasm._OrtTrainingReleaseSession(sessionId);
  }
  if (wasm._OrtTrainingReleaseCheckpoint) {
    wasm._OrtTrainingReleaseCheckpoint(checkpointId);
  }
};