// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common';

import { OnnxruntimeWebAssemblyBackend } from './backend-wasm';
import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training';

/**
 * WebAssembly backend that adds training-session creation on top of
 * {@link OnnxruntimeWebAssemblyBackend}.
 */
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
  /**
   * Creates a WebAssembly training session handler and initializes its
   * training session from the supplied artifacts.
   *
   * Each artifact is forwarded unchanged to
   * `OnnxruntimeWebAssemblyTrainingSessionHandler.createTrainingSession`.
   *
   * @param checkpointStateUriOrBuffer - Checkpoint state, as a string or a byte buffer.
   * @param trainModelUriOrBuffer - Training model, as a string or a byte buffer.
   * @param evalModelUriOrBuffer - Evaluation model, as a string or a byte buffer.
   * @param optimizerModelUriOrBuffer - Optimizer model, as a string or a byte buffer.
   * @param options - Session options passed through to the handler.
   * @returns The handler, once `createTrainingSession` has completed.
   * @throws Rejects with whatever error `createTrainingSession` rejects with.
   */
  async createTrainingSessionHandler(
    checkpointStateUriOrBuffer: string | Uint8Array,
    trainModelUriOrBuffer: string | Uint8Array,
    evalModelUriOrBuffer: string | Uint8Array,
    optimizerModelUriOrBuffer: string | Uint8Array,
    options: InferenceSession.SessionOptions,
  ): Promise<TrainingSessionHandler> {
    const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
    await handler.createTrainingSession(
      checkpointStateUriOrBuffer,
      trainModelUriOrBuffer,
      evalModelUriOrBuffer,
      optimizerModelUriOrBuffer,
      options,
    );
    // An async function already wraps its return value in a promise, so no
    // explicit Promise.resolve is needed here.
    return handler;
  }
}

/** Shared instance of the training-enabled WebAssembly backend. */
export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();