import type { ReactiveController } from 'lit' import type { NeuralNetwork } from '@/app' import type { ModelConf } from '@/types/model_conf' import { ModelUtils } from '@/utils/model_utils' import { DataSetUtils } from '@/utils/data_set_utils' import { AlertUtils } from '@/utils/alert_utils' import * as tf from '@tensorflow/tfjs' import * as tfvis from '@tensorflow/tfjs-vis' export class ModelController implements ReactiveController { host: NeuralNetwork constructor(host: NeuralNetwork) { this.host = host host.addController(this) } // HOST LIFECYCLE - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - hostConnected() { // add event listeners for model related events on host this.host.renderRoot.addEventListener( 'set-train-option', ( e: CustomEvent<{ option: string value: string }> ) => this.setTrainOption(e.detail.option, e.detail.value) ) this.host.renderRoot.addEventListener('discard-model', (_e: Event) => this.discardModel() ) this.host.renderRoot.addEventListener( 'train-model', (e: CustomEvent) => { setTimeout(() => { this.trainModel(e.detail) }, 0) } ) this.host.renderRoot.addEventListener( 'predict-model', (e: CustomEvent>) => this.predictModel(e.detail) ) this.host.renderRoot.addEventListener('delete-prediction', (_e: Event) => this.deletePrediction() ) } // METHODS - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // the container where the metrics are rendered into informs the host about // itself when connected. Then, here, a reference to this container is stored // to render the metrics into it. setTrainMetricsContainer(container: HTMLDivElement) { this.host.trainMetricsContainer = container } // changes a (hyper-)parameter for the training setTrainOption(option: string, value: string) { this.host.trainOptions[option] = value this.host.trainOptions = { ...this.host.trainOptions, } } // discards the current model by deleting the model, deleting the metrics and // informing the network so it can also respond to it (remove model // references) discardModel(): void { this.host.modelConf = ( JSON.parse(JSON.stringify(ModelUtils.defaultModelConf)) ) // empty the container for the metrics. if we did not do this, it would // also show the metrics from the previous training if (this.host.trainMetricsContainer) { this.host.trainMetricsContainer.innerHTML = '' } // remove model references (like tensor and weights) in the network if (this.host.network) { this.host.network.tensorConfs = new Map() this.host.networkController.updateLayerConfs() } } // tells the network to build a model and compiles it - usually this method is // called when training is started and there is no model yet. buildModel(): void { this.discardModel() const model = this.host.networkController.buildModel() if (model && this.host.dataSet) { this.host.modelConf.plottedMetrics.push('loss') this.host.modelConf.plottedMetrics.push('val_loss') if (this.host.dataSet.type == 'regression') { this.host.modelConf.loss = 'meanSquaredError' } else if (this.host.dataSet.type == 'classification') { this.host.modelConf.loss = 'categoricalCrossentropy' this.host.modelConf.metrics.push('acc') this.host.modelConf.plottedMetrics.push('acc') this.host.modelConf.plottedMetrics.push('val_acc') } this.host.modelConf = { ...this.host.modelConf } const optimizer = tf.train.adam( parseFloat(this.host.trainOptions.learningRate) ) model.compile({ optimizer, loss: this.host.modelConf.loss, metrics: this.host.modelConf.metrics, }) this.host.modelConf.model = model AlertUtils.spawn({ message: `The model was successfully compiled! All hyperparameter and network architecture changes were taken into account!`, variant: 'success', icon: 'check-circle', }) this.host.modelConf = { ...this.host.modelConf } } } // trains the model for a specific number of epochs by performing the training // itself but also handing training data to the network during the training to // visualize it and update the metrics trainModel(epochs: number): void { // first build the model if it does not exist if (!this.host.modelConf.model) { this.buildModel() } // now we should have a model and can start training if (this.host.modelConf.model) { // add the number of epochs we want to train to the total epoch count this.host.modelConf.totalEpochs += epochs // set training state this.host.modelConf.isTraining = true // save the changes this.host.modelConf = { ...this.host.modelConf } // inputs const inputs: tf.Tensor[] = [] for (const inputLayer of this.host.network.getInputLayers()) { const inputData: number[][] = [] inputData.push( ...DataSetUtils.getFeatureDataByKeys( this.host.dataSet, inputLayer.conf.featureKeys ) ) inputs.push(tf.tensor(inputData)) } // label tensor depends on regression vs classification type const labelData: number[] = DataSetUtils.getLabelData(this.host.dataSet) let labels: tf.Tensor if (this.host.dataSet.type == 'regression') { labels = tf.tensor(labelData) } else if ( this.host.dataSet.type == 'classification' && this.host.dataSet.labelDesc.classes ) { labels = tf.oneHot( tf.tensor(labelData, undefined, 'int32'), this.host.dataSet.labelDesc.classes.length ) } else { return } void tfvis.show.history( this.host.trainMetricsContainer, this.host.modelConf.history, ['loss', 'val_loss', 'val_acc', ...this.host.modelConf.metrics], { height: 100, xLabel: 'Epoch', } ) /* void inputs[0].data().then((data) => { console.log(input[0].shape) console.log(data) }) void labels.data().then((data) => { console.log() console.log(data) }) */ // start the training itself void this.host.modelConf.model .fit(inputs, labels, { epochs: this.host.modelConf.totalEpochs, batchSize: parseInt(this.host.trainOptions.batchSize), validationSplit: 0.1, callbacks: [ { onBatchEnd: (batch: number, _logs: tf.Logs) => { // update the act batch var (for displaying purposes) this.host.modelConf.actBatch = batch + 1 // update the weights to be displayed in the neurons this.host.networkController.updateWeights( this.host.modelConf.model.getWeights() ) // update the model to reflect all changes this.host.modelConf = { ...this.host.modelConf, } }, onEpochEnd: (epoch: number, logs: tf.Logs) => { this.host.modelConf.actEpoch = epoch + 1 this.host.modelConf.history.push(logs) this.host.modelConf = { ...this.host.modelConf, } void tfvis.show.history( this.host.trainMetricsContainer, this.host.modelConf.history, [ 'loss', 'val_loss', 'val_acc', ...this.host.modelConf.metrics, ], { height: 100, xLabel: 'Epoch', } ) }, }, ], initialEpoch: this.host.modelConf.actEpoch, }) .then((info) => { console.log(info) }) .catch((err) => { console.error(err) }) .finally(() => { this.host.modelConf.isTraining = false this.host.modelConf = { ...this.host.modelConf, } }) } } // predict a label based on a input object predictModel(inputObject: Record): void { // inputs const inputs: tf.Tensor[] = [] for (const inputLayer of this.host.network.getInputLayers()) { const inputData: number[] = [] Object.keys(inputObject).forEach((featureKey) => { if (inputLayer.conf.featureKeys.includes(featureKey)) { inputData.push(inputObject[featureKey]) } }) inputs.push(tf.tensor([inputData])) } const predictedTensor = this.host.modelConf.model.predict(inputs) const predictedArray = predictedTensor.arraySync() console.log('PREDICTION INPUTS:') console.log(inputObject) console.log('PREDICTED ARRAY:') console.log(predictedArray) this.host.modelConf.predictedValue = predictedArray[0] this.host.modelConf = { ...this.host.modelConf } console.log(this.host.modelConf.predictedValue) } // deletes the current prediction - this is used to be able to make a new // prediction request deletePrediction(): void { this.host.modelConf.predictedValue = undefined this.host.modelConf = { ...this.host.modelConf } } }