import * as tf from '../../dist/tfjs.esm';
import { NetInput, TNetInput, toNetInput } from '../dom/index';
import { FaceFeatureExtractor } from '../faceFeatureExtractor/FaceFeatureExtractor';
import { FaceFeatureExtractorParams } from '../faceFeatureExtractor/types';
import { FaceProcessor } from '../faceProcessor/FaceProcessor';
import { FaceExpressions } from './FaceExpressions';

/**
 * Classifier that turns the output of the shared face-processing network into
 * softmax-normalised scores wrapped in `FaceExpressions` objects.
 *
 * NOTE(review): `FaceFeatureExtractorParams` is imported but not referenced here;
 * if `FaceProcessor` is generic, its type argument may have been lost — confirm
 * against the base class declaration.
 */
export class FaceExpressionNet extends FaceProcessor {
  /**
   * @param faceFeatureExtractor Feature extractor handed to the base class;
   *   a fresh `FaceFeatureExtractor` is created when omitted.
   */
  constructor(faceFeatureExtractor: FaceFeatureExtractor = new FaceFeatureExtractor()) {
    super('FaceExpressionNet', faceFeatureExtractor);
  }

  /**
   * Runs the network synchronously and applies softmax to its output.
   *
   * The computation is wrapped in `tf.tidy`, so intermediate tensors are
   * released; the returned tensor is owned by the caller and must be disposed.
   *
   * @param input Prepared network input or a raw 4D tensor.
   * @returns Softmax of `this.runNet(input)`.
   */
  public forwardInput(input: NetInput | tf.Tensor4D): tf.Tensor2D {
    return tf.tidy(() => tf.softmax(this.runNet(input)));
  }

  /**
   * Converts `input` with `toNetInput` and forwards it through the network.
   *
   * @param input Any input accepted by `toNetInput`.
   * @returns The softmax output tensor; the caller is responsible for disposing it.
   */
  public async forward(input: TNetInput): Promise<tf.Tensor2D> {
    return this.forwardInput(await toNetInput(input));
  }

  /**
   * Computes expression predictions for `input`.
   *
   * @param input Any input accepted by `toNetInput`.
   * @returns An array of `FaceExpressions` (one per unstacked output row) when
   *   the converted input reports `isBatchInput`, otherwise the first entry only.
   */
  public async predictExpressions(input: TNetInput): Promise<FaceExpressions | FaceExpressions[]> {
    const netInput = await toNetInput(input);
    // forwardInput is synchronous, so there is nothing to await here.
    const out = this.forwardInput(netInput);

    try {
      const probabilitiesByBatch = await Promise.all(
        tf.unstack(out).map(async (row) => {
          try {
            // Asynchronous read-back: does not block the calling thread while
            // the backend downloads the values.
            return await row.data();
          } finally {
            // Release the per-row tensor even when the read fails.
            row.dispose();
          }
        }),
      );

      const predictionsByBatch = probabilitiesByBatch
        .map((probabilities) => new FaceExpressions(probabilities as Float32Array));

      return netInput.isBatchInput ? predictionsByBatch : predictionsByBatch[0];
    } finally {
      // Always release the network output, including on error paths.
      out.dispose();
    }
  }

  /** @returns The model name used by default for this network. */
  protected getDefaultModelName(): string {
    return 'face_expression_model';
  }

  /** @returns Number of input channels of the classifier (256). */
  protected getClassifierChannelsIn(): number {
    return 256;
  }

  /** @returns Number of output channels of the classifier (7). */
  protected getClassifierChannelsOut(): number {
    return 7;
  }
}