// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {copyFromExternalBuffer} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';

/**
 * TrainingSessionHandler implementation that forwards every operation to the functions exported by
 * './wasm-training-core-impl', identifying the session and its checkpoint by the numeric handles obtained in
 * createTrainingSession().
 */
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
  // Both handles are assigned by createTrainingSession(); they are not valid before that call completes.
  private sessionId: number;
  private checkpointId: number;

  inputNames: string[];
  outputNames: string[];
  // These stay empty unless an eval model is supplied to createTrainingSession().
  evalInputNames: string[] = [];
  evalOutputNames: string[] = [];

  /**
   * Resolves a URI or an in-memory byte array to the buffer descriptor produced by copyFromExternalBuffer().
   *
   * @param uriOrBuffer a string is treated as a URI and downloaded with fetch(); a Uint8Array is used as-is.
   * @returns the value returned by copyFromExternalBuffer() for the resolved bytes.
   */
  async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableInternalBuffer> {
    let buffer: Uint8Array;
    if (typeof uriOrBuffer === 'string') {
      const response = await fetch(uriOrBuffer);
      const arrayBuffer = await response.arrayBuffer();
      buffer = new Uint8Array(arrayBuffer);
    } else {
      buffer = uriOrBuffer;
    }
    return copyFromExternalBuffer(buffer);
  }

  /**
   * Loads the checkpoint and models, creates the checkpoint and training-session handles, and caches the
   * input/output names of the training model (and of the eval model, when one is given).
   *
   * @param checkpointStateUriOrBuffer URI or bytes of the checkpoint state.
   * @param trainModelUriOrBuffer URI or bytes of the training model.
   * @param evalModelUriOrBuffer URI or bytes of the eval model; the empty string means "no eval model".
   * @param optimizerModelUriOrBuffer URI or bytes of the optimizer model; the empty string means "no optimizer model".
   * @param options session options forwarded to createTrainingSessionHandle().
   */
  async createTrainingSession(
      checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
      evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
      options: InferenceSession.SessionOptions): Promise<void> {
    const hasEvalModel = evalModelUriOrBuffer !== '';
    const hasOptimizerModel = optimizerModelUriOrBuffer !== '';

    const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
    const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
    // 0 is supposed to be the nullptr
    let evalModelData: SerializableInternalBuffer = [0, 0];
    let optimizerModelData: SerializableInternalBuffer = [0, 0];

    if (hasEvalModel) {
      evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
    }
    if (hasOptimizerModel) {
      optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
    }

    this.checkpointId = createCheckpointHandle(checkpointData);
    this.sessionId =
        createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
    [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
    if (hasEvalModel) {
      [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
    }
  }

  /**
   * Helper method that converts a feeds or fetches datatype to three parallel arrays: the values, the index of
   * each value's name within the list of names provided, and the result of applying mapFunc to each value.
   *
   * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType
   * @param names either inputNames or outputNames
   * @param mapFunc called once per value, in the iteration order of Object.entries(feeds), with the value and
   *     its position in that order.
   * @returns a tuple of the list of values, the list of indices, and the list of mapped values.
   * @throws Error if a key of feeds does not appear in names.
   */
  convertMapIntoValuesArrayAndIndicesArray<T, U>(
      feeds: {[name: string]: T}, names: string[], mapFunc: (val: T, index: number) => U): [T[], number[], U[]] {
    const values: T[] = [];
    const indices: number[] = [];
    for (const [name, tensor] of Object.entries(feeds)) {
      const index = names.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      values.push(tensor);
      indices.push(index);
    }

    // Mapping happens only after every name has been validated.
    const uList = values.map(mapFunc);
    return [values, indices, uList];
  }

  /**
   * Helper method that converts the TensorMetadata that the wasm-core functions return to the
   * SessionHandler.ReturnType. Any entries in the provided outputArray that are null or undefined will be
   * populated with the corresponding decoded result.
   *
   * @param results used to populate the resultMap if there is no value for that output name already
   * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results
   * @param outputIndices specifies which output name the corresponding value for outputArray refers to.
   * @param outputNames the list of names that outputIndices indexes into. Defaults to this.outputNames; pass
   *     this.evalOutputNames when the indices were computed against the eval model's outputs.
   * @returns a map of output names and OnnxValues.
   */
  convertTensorMetadataToReturnType(
      results: TensorMetadata[], outputArray: Array<Tensor|null>, outputIndices: number[],
      outputNames: string[] = this.outputNames): SessionHandler.ReturnType {
    const resultMap: SessionHandler.ReturnType = {};
    for (let i = 0; i < results.length; i++) {
      resultMap[outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
    }
    return resultMap;
  }

  /** Forwards to lazyResetGrad() for this session. */
  async lazyResetGrad(): Promise<void> {
    await lazyResetGrad(this.sessionId);
  }

  /**
   * Runs one training step.
   *
   * @param feeds input values keyed by name; every key must be one of this.inputNames.
   * @param fetches output values keyed by name; every key must be one of this.outputNames. A falsy entry is
   *     passed to the backend as null and later filled from the step's results.
   * @param options run options forwarded to runTrainStep().
   * @returns a map from output name to value.
   */
  async runTrainStep(
      feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
    // The names are captured up front so the error-message callbacks below never touch `inputIndices` /
    // `outputIndices`, which are still uninitialized while the helper is mapping. Object.keys() iterates in the
    // same order as the Object.entries() call inside the helper, so position i refers to the same entry.
    const feedNames = Object.keys(feeds);
    const fetchNames = Object.keys(fetches);

    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray(
        feeds, this.inputNames, (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${feedNames[i]}"`));

    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray(
        fetches, this.outputNames,
        (t, i): TensorMetadata|null => t ? encodeTensorMetadata(t, () => `output "${fetchNames[i]}"`) : null);

    const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
  }

  /**
   * Forwards to runOptimizerStep() for this session.
   *
   * @param options run options forwarded unchanged.
   */
  async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
    await runOptimizerStep(this.sessionId, options);
  }

  /**
   * Runs one evaluation step.
   *
   * @param feeds input values keyed by name; every key must be one of this.evalInputNames.
   * @param fetches output values keyed by name; every key must be one of this.evalOutputNames. A falsy entry is
   *     passed to the backend as null and later filled from the step's results.
   * @param options run options forwarded to runEvalStep().
   * @returns a map from eval output name to value.
   */
  async runEvalStep(
      feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
    // See runTrainStep() for why the names are captured before the helper is called.
    const feedNames = Object.keys(feeds);
    const fetchNames = Object.keys(fetches);

    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray(
        feeds, this.evalInputNames,
        (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${feedNames[i]}"`));

    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray(
        fetches, this.evalOutputNames,
        (t, i): TensorMetadata|null => t ? encodeTensorMetadata(t, () => `output "${fetchNames[i]}"`) : null);

    const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    // outputIndices were computed against evalOutputNames, so the result map must be keyed from that same list.
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.evalOutputNames);
  }

  /**
   * Forwards to getParametersSize() for this session.
   *
   * @param trainableOnly forwarded unchanged.
   */
  async getParametersSize(trainableOnly: boolean): Promise<number> {
    return getParametersSize(this.sessionId, trainableOnly);
  }

  /**
   * Forwards to loadParametersBuffer() for this session.
   *
   * @param array the bytes to load.
   * @param trainableOnly forwarded unchanged.
   */
  async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
    await loadParametersBuffer(this.sessionId, array, trainableOnly);
  }

  /**
   * Fetches the parameters through getContiguousParameters() and decodes the returned tensor metadata.
   *
   * @param trainableOnly forwarded unchanged.
   * @returns the decoded value.
   */
  async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
    const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
    return decodeTensorMetadata(tensorResult);
  }

  /** Releases both the checkpoint handle and the training-session handle. */
  async dispose(): Promise<void> {
    return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
  }
}