/**
 * Type declarations for a UNet model and its building blocks.
 *
 * NOTE(review): this looks like `tsc --declaration` output (everything was
 * emitted on one line and ends with `export {}`); if so, edit the source `.ts`
 * file instead — changes here will be overwritten on the next build.
 *
 * Only `UNetModel` is exported; the helper classes and the `BB` alias below
 * are module-local declarations that describe the types of its fields.
 */
import { Conv2d, GroupNorm, type Layer, LayerNorm, Linear, Tensor } from '@jsgrad/jsgrad/base';

/**
 * Block with three layer lists and an optional `Conv2d` skip connection;
 * `call` takes an input tensor and an embedding tensor.
 * NOTE(review): presumably a residual block conditioned on a (timestep)
 * embedding, with `skip_connection` only set when `channels !== out_channels`
 * — confirm against the implementation.
 */
declare class ResBlock {
  in_layers: Layer[];
  emb_layers: Layer[];
  out_layers: Layer[];
  /** Optional; when it is present is not visible from the declaration. */
  skip_connection?: Conv2d;
  constructor(channels: number, emb_channels: number, out_channels: number);
  call: (x: Tensor, emb: Tensor) => Tensor;
}

/**
 * Attention layer with separate query/key/value projections and an output
 * projection held in a single-element tuple (`to_out`).
 * NOTE(review): `ctx` is optional in `call`; presumably the layer attends over
 * `x` itself when `ctx` is omitted — confirm against the implementation.
 */
declare class CrossAttention {
  n_heads: number;
  d_head: number;
  to_q: Linear;
  to_k: Linear;
  to_v: Linear;
  /** Single-element tuple, not a general array. */
  to_out: [Linear];
  constructor(query_dim: number, ctx_dim: number, n_heads: number, d_head: number);
  call: (x: Tensor, ctx?: Tensor) => Tensor;
}

/**
 * Feed-forward stack stored as a list of layers (`net`).
 * The default used when `mult` is omitted is not visible here.
 */
declare class FeedForward {
  net: Layer[];
  constructor(dim: number, mult?: number);
  call: (x: Tensor) => Tensor;
}

/**
 * Transformer block composed of two `CrossAttention` layers, one
 * `FeedForward`, and three `LayerNorm`s. `ctx` is optional in `call`.
 */
declare class BasicTransformerBlock {
  attn1: CrossAttention;
  ff: FeedForward;
  attn2: CrossAttention;
  norm1: LayerNorm;
  norm2: LayerNorm;
  norm3: LayerNorm;
  constructor(dim: number, ctx_dim: number, n_heads: number, d_head: number);
  call: (x: Tensor, ctx?: Tensor) => Tensor;
}

/**
 * Group-normalised input, an input projection, a list of
 * `BasicTransformerBlock`s, and an output projection.
 * `proj_in` / `proj_out` are typed `Linear | Conv2d`.
 * NOTE(review): presumably `use_linear` selects `Linear` over `Conv2d` for the
 * projections, and `depth` sets `transformer_blocks.length` — confirm.
 * `ctx_dim` may be a single number or an array; how an array is consumed
 * is not visible from the declaration.
 */
declare class SpatialTransformer {
  use_linear: boolean;
  norm: GroupNorm;
  proj_in: Linear | Conv2d;
  transformer_blocks: BasicTransformerBlock[];
  proj_out: Linear | Conv2d;
  constructor(channels: number, n_heads: number, d_head: number, ctx_dim: number | number[], use_linear: boolean, depth?: number);
  call: (x: Tensor, ctx?: Tensor) => Tensor;
}

/** Wraps a single `Conv2d` (`op`). Presumably halves spatial size — confirm. */
declare class Downsample {
  op: Conv2d;
  constructor(channels: number);
  call: (x: Tensor) => Tensor;
}

/** Wraps a single `Conv2d` (`conv`). Presumably doubles spatial size — confirm. */
declare class Upsample {
  conv: Conv2d;
  constructor(channels: number);
  call: (x: Tensor) => Tensor;
}

/** Any element that may appear inside `UNetModel`'s input/middle/output blocks. */
type BB = ResBlock | SpatialTransformer | Conv2d | Downsample | Upsample;

/**
 * UNet model: a time-embedding stack, an optional label-embedding stack,
 * input / middle / output block lists (each element typed `BB`), and an
 * output layer list.
 *
 * NOTE(review): the constructor takes `num_res_blocks` as a single number,
 * but the field of the same name is `number[]`; presumably the implementation
 * expands it (e.g. once per `channel_mult` entry) — confirm.
 * NOTE(review): `adm_in_ch` is a required positional parameter that may be
 * `undefined`; presumably `label_emb` is only populated when it is defined and
 * `y` in `call` is only meaningful in that case — confirm.
 */
export declare class UNetModel {
  model_ch: number;
  attention_resolutions: number[];
  d_head?: number | undefined;
  n_heads?: number | undefined;
  num_res_blocks: number[];
  time_embed: Layer[];
  label_emb?: Layer[][];
  input_blocks: BB[][];
  middle_block: BB[];
  output_blocks: BB[][];
  out: Layer[];
  constructor(adm_in_ch: number | undefined, in_ch: number, out_ch: number, model_ch: number, attention_resolutions: number[], num_res_blocks: number, channel_mult: number[], transformer_depth: number[], ctx_dim: number | number[], use_linear?: boolean, d_head?: number | undefined, n_heads?: number | undefined);
  /**
   * @param x - model input tensor (shape not declared here)
   * @param tms - tensor passed alongside `x`; presumably timesteps — confirm
   * @param ctx - context tensor (required here, unlike the sub-blocks)
   * @param y - optional extra conditioning tensor
   */
  call: (x: Tensor, tms: Tensor, ctx: Tensor, y?: Tensor) => Tensor;
}

// Emitted by tsc so that only the explicitly exported names above form this
// module's public surface; the other declarations stay module-local.
export {};