/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import * as tf from '@tensorflow/tfjs';
import type { CustomCallbackArgs } from '@tensorflow/tfjs';
import { CustomHandPose, type Metadata } from './custom-handpose';

/** Hyper-parameters accepted by `TeachableHandPose.train()`. */
export interface TrainingParameters {
  /**
   * Unit count of the first Dense layer of the classifier head; the second
   * Dense layer uses `denseUnits / 2` (see the architecture note on `train()`).
   */
  denseUnits: number;
  /** Number of training epochs. */
  epochs: number;
  /** Learning rate used for training. */
  learningRate: number;
  /** Batch size used for training. */
  batchSize: number;
}

/**
 * Trainable hand-pose classifier: collects feature-vector examples per class,
 * trains a Dense classifier head on them and exposes prediction, persistence
 * and evaluation helpers.
 *
 * NOTE(review): the type arguments of the `Promise` return types below were
 * missing (bare `Promise` is a TS2314 error and silently degrades to `any`
 * under `skipLibCheck`). They have been restored; the ones on `predict`,
 * `predictTopK`, `train` and `stopTraining` follow the Teachable Machine
 * convention and should be confirmed against the implementation.
 */
export declare class TeachableHandPose extends CustomHandPose {
  private trainDataset;
  private validationDataset;
  private __stopTrainingResolve?;
  /** Array of training examples per class, each element is a class bucket */
  examples: Float32Array[][];
  private seed?;

  /** True once the model has been trained (more than 2 layers built) */
  get isTrained(): boolean;
  get isPrepared(): boolean;
  get numClasses(): number;

  /**
   * Add a sample of data under the provided class index.
   * @param className Integer index for the class this example belongs to
   * @param sample Feature vector from `handOutputsToArray()`
   */
  addExample(className: number, sample: Float32Array): void;

  /**
   * Classify a hand feature vector using the trained model.
   * Returns the full probability distribution across all classes.
   * @param handOutput Feature vector to classify
   * @returns One entry per class.
   *   NOTE(review): element shape assumed to be `{ className, probability }` — confirm.
   */
  predict(
    handOutput: Float32Array
  ): Promise<{ className: string; probability: number }[]>;

  /**
   * Classify a hand feature vector and return the top-K predictions.
   * @param handOutput Feature vector to classify
   * @param maxPredictions Optional cap on the number of predictions returned
   *   (the default is not visible in this declaration)
   * @returns The top-K predictions.
   *   NOTE(review): element shape assumed to match `predict()` — confirm.
   */
  predictTopK(
    handOutput: Float32Array,
    maxPredictions?: number
  ): Promise<{ className: string; probability: number }[]>;

  /**
   * Pre-process collected examples into train / validation tf.data.Datasets.
   * Call this before `train()`; if it has not been called, `train()` calls it
   * automatically.
   */
  prepare(): void;
  private convertToTfDataset;

  /**
   * Build and train the Dense classifier head on the collected examples.
   *
   * Architecture:
   *   Dense(denseUnits, relu) → BatchNorm → Dropout(0.3)
   *   → Dense(denseUnits/2, relu) → BatchNorm → Dropout(0.2)
   *   → Dense(numClasses, softmax)
   *
   * @param params Training hyper-parameters
   * @param callbacks Keras-compatible training callbacks
   * @returns NOTE(review): assumed to resolve with the trained model — confirm.
   */
  train(
    params: TrainingParameters,
    callbacks?: CustomCallbackArgs
  ): Promise<tf.LayersModel>;

  /**
   * Save the trained model to a given IO handler or URL.
   * @param handlerOrURL Destination IO handler or URL string
   * @param config Optional save configuration
   * @returns The save result reported by TensorFlow.js
   */
  save(
    handlerOrURL: tf.io.IOHandler | string,
    config?: tf.io.SaveConfig
  ): Promise<tf.io.SaveResult>;

  /** Initialise the examples array (one bucket per class). Must be called after setLabels(). */
  prepareDataset(): void;

  /**
   * Request that an in-progress training run stops.
   * @returns NOTE(review): assumed to resolve without a value once training
   *   has stopped — confirm.
   */
  stopTraining(): Promise<void>;
  dispose(): void;

  setLabel(index: number, label: string): void;
  setLabels(labels: string[]): void;
  getLabel(index: number): string;
  getLabels(): string[];

  setName(name: string): void;
  getName(): string | undefined;

  /** Optional seed for reproducible data shuffling. */
  setSeed(seed: string): void;

  /**
   * Calculate per-class accuracy on the held-out validation set.
   * Returns tensors with predicted and ground-truth class indices.
   */
  calculateAccuracyPerClass(): Promise<{
    reference: tf.Tensor;
    predictions: tf.Tensor;
  }>;
}

/**
 * Create a new TeachableHandPose instance ready for training.
 * @param metadata Partial metadata (labels can be added later via setLabels)
 * @returns The newly created instance
 */
export declare function createTeachable(
  metadata: Partial<Metadata>
): Promise<TeachableHandPose>;