import type { Dtype } from "./dtype";
import { Module, ModuleList, Sequential } from "./nn_module";
import type { Tensor } from ".";

/** Options accepted by the {@link UNetModel} constructor. */
export interface UNetModelConfig {
    /** channels of input tensor */
    inChannels: number;
    /** base channels in model */
    modelChannels: number;
    /** channels of output tensor */
    outChannels: number;
    /** number of residual blocks per down/up sampling stage */
    numResBlocks?: number;
    /**
     * a collection of downsample rates at which attention is applied.
     * For example, if this contains 4, then at 4x downsampling, attention is applied.
     */
    attentionResolutions?: number[];
    /** the dropout probability */
    dropout?: number;
    /** channel multiples per down/up sampling stage */
    channelMult?: number[];
    /** if `true` use learnable convolutional upsampling/downsampling */
    convResample?: boolean;
    /** determines whether to use 1D, 2D, or 3D convolutions */
    dims?: number;
    /** if specified, then this model will be class-conditioned with `numClasses` classes */
    numClasses?: number;
    useCheckpoint?: boolean;
    dtype?: Dtype;
    /** the number of attention heads in each attention layer */
    numHeads?: number;
    /** if specified, ignore numHeads and instead use this number of channels in each attention head */
    numHeadChannels?: number;
    numHeadsUpSample?: number;
    /** use a FiLM-like conditioning mechanism */
    useScaleShiftNorm?: boolean;
    /** use residual blocks for up/down sampling */
    resblockUpdown?: boolean;
    useNewAttentionOrder?: boolean;
    useSpatialTransformer?: boolean;
    transformerDepth?: number;
    contextDim?: number;
}

/**
 * Full UNet model with attention and timestep embedding.
* * ** This is still a work in progress ** */ export declare class UNetModel extends Module { inChannels: number; modelChannels: number; outChannels: number; numResBlocks: number; attentionResolutions: number[]; dropout: number; channelMult: number[]; convResample: boolean; numClasses: number | null; useCheckpoint: boolean; dtype: Dtype; numHeadChannels: number; private _featureSize; timeEmbed: Sequential; inputBlocks: ModuleList; middleBlock: TimestepEmbedSequential; outputBlocks: ModuleList; out: Sequential; /** * Constructs a new UNet model with a given number of input and output channels along with * a variety of other options. */ constructor(config: UNetModelConfig); /** * Apply the model to the input batch. * @param x a [B, C, ...] Tensor of inputs. * @param timesteps a 1-D batch of timesteps. * @param context conditioning from cross-attention. * @param y a [B] Tensor of labels, if cross-conditional. * @returns a [B, C, ...] Tensor of outputs. */ forward(x: Tensor, timesteps: Tensor, context?: Tensor, y?: Tensor): Tensor; } declare class TimestepBlock extends Module { } /** * A sequential module that passes timestep embeddings to the children that * support it as an extra input. */ export declare class TimestepEmbedSequential extends TimestepBlock { constructor(...modules: Module[]); forward(x: Tensor, emb: Tensor, context?: Tensor): Tensor; } export {};