import * as tf from '../../dist/tfjs.esm';
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
import { NetParams } from './types';

/**
 * Builds the fully-connected age and gender head parameters from a named
 * tensor map.
 *
 * Every lookup goes through a helper created by `extractWeightEntryFactory`,
 * which is given the `paramMappings` array. Presumably each lookup records its
 * path there (confirm in `../common/index`). `disposeUnusedWeightTensors` is
 * then called with the same map and mappings. It presumably frees tensors that
 * were never looked up; verify against its implementation.
 *
 * @param weightMap - Named tensors to read the `fc/age/*` and `fc/gender/*` entries from.
 * @returns The assembled `params` object and the `paramMappings` collected while reading.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];
  const getEntry = extractWeightEntryFactory(weightMap, paramMappings);

  // Reads `<prefix>/weights` and then `<prefix>/bias`. Object-literal
  // properties are evaluated left to right, so the lookup order matches a
  // weights-then-bias sequence.
  // NOTE(review): the numeric argument (2 / 1) is presumably the expected
  // tensor rank. Confirm in extractWeightEntryFactory.
  const readFc = (prefix: string): FCParams => ({
    weights: getEntry(`${prefix}/weights`, 2),
    bias: getEntry(`${prefix}/bias`, 1),
  });

  // Keep age before gender so the lookup sequence is unchanged.
  const age = readFc('fc/age');
  const gender = readFc('fc/gender');
  const params = { fc: { age, gender } };

  disposeUnusedWeightTensors(weightMap, paramMappings);
  return { params, paramMappings };
}