import { isTensor } from '../utils/index';
import { ParamMapping } from './types';

/**
 * Creates a lookup function bound to a weight map and a shared list of
 * parameter mappings.
 *
 * @param weightMap - Object indexed by weight path; entries are read with
 *   `weightMap[originalPath]`.
 * @param paramMappings - Array that the returned function appends to on every
 *   successful lookup. It is mutated in place.
 * @returns A function `(originalPath, paramRank, mappedPath?)` that returns the
 *   entry stored under `originalPath` after it passes the
 *   `isTensor(entry, paramRank)` check.
 */
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
  /**
   * Reads one entry from the weight map, validates it, and records the mapping.
   *
   * @param originalPath - Key used to read the entry from the weight map.
   * @param paramRank - Rank forwarded to `isTensor` for validation.
   * @param mappedPath - Optional path recorded as `paramPath`; any falsy value
   *   (including an empty string) falls back to `originalPath`.
   * @throws Error when `isTensor` rejects the entry for the requested rank.
   */
  const extractWeightEntry = (originalPath: string, paramRank: number, mappedPath?: string) => {
    const entry = weightMap[originalPath];

    // Reject anything that isTensor does not accept before touching paramMappings,
    // so a failed lookup never records a mapping.
    if (!isTensor(entry, paramRank)) {
      throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${entry}`);
    }

    const paramPath = mappedPath ? mappedPath : originalPath;
    paramMappings.push({ originalPath, paramPath });

    return entry;
  };

  return extractWeightEntry;
}