import { LitElementWw } from '@webwriter/lit'
import { TemplateResult, html } from 'lit'
import { customElement, queryAll, property } from 'lit/decorators.js'
import { consume } from '@lit/context'

import type { CLayerConf } from '@/types/c_layer_conf'
import type { TensorConf } from '@/types/tensor_conf'
import type { CLayerConnectionConf } from '@/types/c_layer_connection_conf'
import { layerConfsContext } from '@/contexts/layer_confs_context'
import { layerConnectionConfsContext } from '@/contexts/layer_con_confs_context'

import { CLayer } from '@/components/network/c_layer'
import { InputLayer } from '@/components/network/input_layer'
import { OutputLayer } from '@/components/network/output_layer'
import { CLayerConnection } from '@/components/network/c_layer_connection'
import { DenseLayer } from '@/components/network/dense_layer'

import '@/components/network/input_layer'
import '@/components/network/dense_layer'
import '@/components/network/output_layer'
import '@/components/network/c_layer'
import '@/components/network/c_layer_connection'
import { CEdge } from './c_edge'
import { CNeuron } from './neuron'

/**
 * Renders the network: one element per layer configuration and one
 * `c-layer-connection` element per layer connection configuration, and offers
 * lookup helpers over the rendered elements.
 */
export class CNetwork extends LitElementWw {
  static scopedElements = {
    "c-edge": CEdge,
    "c-layer-connection": CLayerConnection,
    "dense-layer": DenseLayer,
    "input-layer": InputLayer,
    "output-layer": OutputLayer,
    "c-neuron": CNeuron
  }

  /** Layer configurations, consumed from `layerConfsContext`. */
  @consume({ context: layerConfsContext, subscribe: true })
  accessor layerConfs: CLayerConf[]

  /** Layer connection configurations, consumed from `layerConnectionConfsContext`. */
  @consume({ context: layerConnectionConfsContext, subscribe: true })
  accessor layerConnectionConfs: CLayerConnectionConf[]

  /** Tensor data (tensor, bias, weights) looked up by layer id when rendering layers. */
  @property()
  accessor tensorConfs: Map<number, TensorConf> = new Map()

  /** All rendered elements carrying the `layer` class. */
  @queryAll('.layer')
  accessor _layers: NodeListOf<CLayer>

  /** All rendered `c-layer-connection` elements. */
  @queryAll('c-layer-connection')
  accessor _layerConnections: NodeListOf<CLayerConnection>

  // METHODS - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  // -> GETTING  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  /**
   * Get a rendered layer by its layer id.
   *
   * @param layerId - id to match against each rendered layer's `conf.layerId`
   * @returns the first matching layer; `undefined` at runtime when no rendered
   *   layer matches
   */
  getLayerById(layerId: number): CLayer {
    // NOTE(review): loose equality is kept from the original in case ids can
    // arrive as strings — confirm and tighten to `===`.
    return Array.from(this._layers).find(
      (layer) => layer.conf.layerId == layerId
    )
  }

  /**
   * Get the input layers.
   *
   * @returns the rendered layers (in `layerConfs` order) that are instances of
   *   `InputLayer`
   */
  getInputLayers(): InputLayer[] {
    return this.layerConfs
      .map((layerConf) => this.getLayerById(layerConf.layerId))
      .filter((layer): layer is InputLayer => layer instanceof InputLayer)
  }

  /**
   * Get the output layer.
   *
   * @returns the first rendered layer (in `layerConfs` order) that is an
   *   instance of `OutputLayer`; `undefined` at runtime when there is none
   */
  getOutputLayer(): OutputLayer {
    return this.layerConfs
      .map((layerConf) => this.getLayerById(layerConf.layerId))
      .find((layer): layer is OutputLayer => layer instanceof OutputLayer)
  }

  /**
   * Get a rendered layer connection by the ids of the layers it connects.
   *
   * @param sourceLayerId - id matched against `conf.sourceLayerId`
   * @param targetLayerId - id matched against `conf.targetLayerId`
   * @returns the first connection matching both ids; `undefined` at runtime
   *   when none matches
   */
  getLayerConnectionByLayerIds(
    sourceLayerId: number,
    targetLayerId: number
  ): CLayerConnection {
    return Array.from(this._layerConnections).find(
      (layerConnection) =>
        layerConnection.conf.sourceLayerId == sourceLayerId &&
        layerConnection.conf.targetLayerId == targetLayerId
    )
  }

  /**
   * Get the target (subsequent) layers of a layer.
   *
   * @param source - layer whose outgoing connections are followed
   * @returns one entry per connection conf whose `sourceLayerId` matches
   *   `source`; an entry is `undefined` when its target layer is not rendered
   */
  getTargetsFor(source: CLayer): CLayer[] {
    return this.layerConnectionConfs
      .filter((conConf) => conConf.sourceLayerId == source.conf.layerId)
      .map((conConf) => this.getLayerById(conConf.targetLayerId))
  }

  /**
   * Get the source (preceding) layers of a layer.
   *
   * @param target - layer whose incoming connections are followed
   * @returns one entry per connection conf whose `targetLayerId` matches
   *   `target`; an entry is `undefined` when its source layer is not rendered
   */
  getSourcesFor(target: CLayer): CLayer[] {
    return this.layerConnectionConfs
      .filter((conConf) => conConf.targetLayerId == target.conf.layerId)
      .map((conConf) => this.getLayerById(conConf.sourceLayerId))
  }

  // RENDER  - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

  /**
   * Create the element for a layer configuration.
   *
   * The element is created through the shadow root with the tag named by
   * `layerConf.HTML_TAG` (presumably so the tag resolves against
   * `scopedElements` — TODO confirm), receives the conf and any tensor data
   * stored for its layer id, and is tagged with the `layer` class so that
   * `_layers` can find it.
   *
   * @param layerConf - configuration of the layer to create
   * @returns the newly created layer element
   */
  getHTMLForLayerConf(layerConf: CLayerConf) {
    const layer = (this.shadowRoot as any).createElement(layerConf.HTML_TAG)
    layer.conf = layerConf

    // Tensor data is optional: all three fields stay undefined when there is
    // no entry for this layer id.
    const tensorConf = this.tensorConfs?.get(layerConf.layerId)
    layer.tensor = tensorConf?.tensor
    layer.bias = tensorConf?.bias
    layer.weights = tensorConf?.weights

    layer.classList.add('layer')
    return layer
  }

  /**
   * Create the `c-layer-connection` element for a connection configuration.
   *
   * @param layerConnectionConf - configuration of the connection to create
   * @returns the newly created connection element with its `conf` assigned
   */
  getHTMLForLayerConnectionConf(layerConnectionConf: CLayerConnectionConf) {
    const layerConnection = (this.shadowRoot as any).createElement(
      'c-layer-connection'
    )
    layerConnection.conf = layerConnectionConf
    return layerConnection
  }

  /**
   * Render all layers followed by all layer connections.
   *
   * Falls back to empty lists so rendering does not throw when a consumed
   * context value has not been provided (the accessors have no initializer).
   */
  render(): TemplateResult<1> {
    const layerConfs = this.layerConfs ?? []
    const layerConnectionConfs = this.layerConnectionConfs ?? []
    return html`
${layerConfs.map((layerConf) => this.getHTMLForLayerConf(layerConf))}
${layerConnectionConfs.map((layerConnectionConf) =>
  this.getHTMLForLayerConnectionConf(layerConnectionConf)
)}
`
  }
}