/**
 * Ambient type declarations for the ops module. Every export below is `declare`d and
 * carries no implementation; the file ends with `export {}`.
 *
 * NOTE(review): generic parameter/argument lists that begin with an identifier appear to
 * have been stripped from this copy. Examples: `declare class SimpleMathTrait> {`,
 * `Map, Ops<"STORE">[]>`, `_eval: void>(...)`, bare `Map` / `Set` / `Partial` /
 * `WeakValueMap` / `PatternMatcher`, and free uses of `Ctx` / `Res`. Lists beginning with a
 * string literal (`Ops<"SINK">`) or a tuple (`Generator<[sint, sint], ...>`) survived.
 * As written the file does not parse. Regenerate it from the implementation (presumably
 * compiler-emitted declarations, so confirm) rather than hand-repairing the type arguments.
 */
import type { Buffer } from './device.ts';
import { DType } from './dtype.ts';
import { type ConstType } from './helpers/helpers.ts';
import { Enum, type Metadata, WeakValueMap } from './helpers/helpers.ts';
import type { Renderer } from './renderer/index.ts';
import { ShapeTracker } from './shape/shapetracker.ts';
/** A Variable is typed as a plain UOp. */
export type Variable = UOp;
/** Values accepted where a constant is expected: a ConstType, a Variable (UOp), or an array of ConstType. */
export type ConstLike = ConstType | Variable | ConstType[];
/**
 * Arithmetic, bitwise and comparison method surface. Every method returns the trait's
 * type parameter `T`, which is the type of the subclass.
 * NOTE(review): the `<T ...>` parameter list of this class is missing in this copy.
 */
declare class SimpleMathTrait> {
    alu: (_arg: Ops, ..._src: T[]) => T;
    const_like: (_b: ConstLike) => T;
    ufix: (x: ConstType) => T;
    _binop: (op: Ops, x: ConstType, reverse: boolean) => T;
    logical_not: () => T;
    neg: () => T;
    add: (x: ConstType, reverse?: boolean) => T;
    mul: (x: ConstType, reverse?: boolean) => T;
    bitwise_and: (x: ConstType, reverse?: boolean) => T;
    bitwise_or: (x: ConstType, reverse?: boolean) => T;
    xor: (x: ConstType, reverse?: boolean) => T;
    idiv: (x: ConstType, reverse?: boolean) => T;
    mod: (x: ConstType, reverse?: boolean) => T;
    sub: (x: ConstType, reverse?: boolean) => T;
    div: (x: ConstType, reverse?: boolean) => T;
    lt: (x: ConstType) => T;
    gt: (x: ConstType) => T;
    ge: (x: ConstType) => T;
    le: (x: ConstType) => T;
    ne: (x: ConstType) => T;
    eq: (x: ConstType) => T;
}
/**
 * SimpleMathTrait plus shifts, max/min, `where`, `threefry` and unary float ops.
 * UOp and UPat both extend this class.
 */
export declare class MathTrait> extends SimpleMathTrait {
    lshift: (x: ConstType, reverse?: boolean) => T;
    rshift: (x: ConstType, reverse?: boolean) => T;
    maximum: (x: ConstType) => T;
    // NOTE(review): typed `=> any` while its sibling `maximum` returns `T`. This looks like an
    // un-annotated return in the implementation; annotate it there so this becomes `=> T`.
    minimum: (x: ConstType) => any;
    where: (x: ConstType, y: ConstType) => T;
    threefry: (seed: ConstType) => T;
    reciprocal: () => T;
    sqrt: () => T;
    sin: () => T;
    log2: () => T;
    exp2: () => T;
}
/**
 * Op-code enum built on `Enum`. Each static member is an `Ops` instance whose type
 * argument is its own name literal, e.g. `Ops.ADD: Ops<"ADD">`. `values()` returns
 * an array of Ops.
 */
export declare class Ops extends Enum {
    private static VALUES;
    static values: () => Ops[];
    _id: bigint;
    constructor(name: Name);
    static readonly SINK: Ops<"SINK">;
    static readonly CONTIGUOUS: Ops<"CONTIGUOUS">;
    static readonly CONTIGUOUS_BACKWARD: Ops<"CONTIGUOUS_BACKWARD">;
    static readonly DETACH: Ops<"DETACH">;
    static readonly PRELOAD: Ops<"PRELOAD">;
    static readonly EMPTY: Ops<"EMPTY">;
    static readonly COPY: Ops<"COPY">;
    static readonly BUFFER_VIEW: Ops<"BUFFER_VIEW">;
    static readonly BLOCK: Ops<"BLOCK">;
    static readonly BLOCKSTART: Ops<"BLOCKSTART">;
    static readonly BLOCKFORK: Ops<"BLOCKFORK">;
    static readonly BLOCKEND: Ops<"BLOCKEND">;
    static readonly RESHAPE: Ops<"RESHAPE">;
    static readonly PERMUTE: Ops<"PERMUTE">;
    static readonly EXPAND: Ops<"EXPAND">;
    static readonly PAD: Ops<"PAD">;
    static readonly SHRINK: Ops<"SHRINK">;
    static readonly STRIDE: Ops<"STRIDE">;
    static readonly UNROLL: Ops<"UNROLL">;
    static readonly CONTRACT: Ops<"CONTRACT">;
    static readonly VIEW: Ops<"VIEW">;
    static readonly DEFINE_GLOBAL: Ops<"DEFINE_GLOBAL">;
    static readonly BUFFER: Ops<"BUFFER">;
    static readonly DEFINE_VAR: Ops<"DEFINE_VAR">;
    static readonly DEFINE_LOCAL: Ops<"DEFINE_LOCAL">;
    static readonly DEFINE_ACC: Ops<"DEFINE_ACC">;
    static readonly VALID: Ops<"VALID">;
    static readonly SPECIAL: Ops<"SPECIAL">;
    static readonly NOOP: Ops<"NOOP">;
    static readonly REDUCE_AXIS: Ops<"REDUCE_AXIS">;
    static readonly GEP: Ops<"GEP">;
    static readonly VECTORIZE: Ops<"VECTORIZE">;
    static readonly CAST: Ops<"CAST">;
    static readonly BITCAST: Ops<"BITCAST">;
    static readonly EXP2: Ops<"EXP2">;
    static readonly LOG2: Ops<"LOG2">;
    static readonly SIN: Ops<"SIN">;
    static readonly SQRT: Ops<"SQRT">;
    static readonly RECIP: Ops<"RECIP">;
    static readonly NEG: Ops<"NEG">;
    static readonly LOAD: Ops<"LOAD">;
    static readonly STORE: Ops<"STORE">;
    static readonly INDEX: Ops<"INDEX">;
    static readonly WMMA: Ops<"WMMA">;
    static readonly ADD: Ops<"ADD">;
    static readonly MUL: Ops<"MUL">;
    static readonly IDIV: Ops<"IDIV">;
    static readonly MAX: Ops<"MAX">;
    static readonly MOD: Ops<"MOD">;
    static readonly CMPLT: Ops<"CMPLT">;
    static readonly CMPNE: Ops<"CMPNE">;
    static readonly XOR: Ops<"XOR">;
    static readonly SHL: Ops<"SHL">;
    static readonly SHR: Ops<"SHR">;
    static readonly OR: Ops<"OR">;
    static readonly AND: Ops<"AND">;
    static readonly THREEFRY: Ops<"THREEFRY">;
    static readonly SUB: Ops<"SUB">;
    static readonly FDIV: Ops<"FDIV">;
    static readonly WHERE: Ops<"WHERE">;
    static readonly MULACC: Ops<"MULACC">;
    static readonly ASSIGN: Ops<"ASSIGN">;
    static readonly BIND: Ops<"BIND">;
    static readonly BARRIER: Ops<"BARRIER">;
    static readonly RANGE: Ops<"RANGE">;
    static readonly IF: Ops<"IF">;
    static readonly ENDRANGE: Ops<"ENDRANGE">;
    static readonly ENDIF: Ops<"ENDIF">;
    static readonly VCONST: Ops<"VCONST">;
    static readonly CONST: Ops<"CONST">;
    static readonly DEVICE: Ops<"DEVICE">;
    static readonly MULTI: Ops<"MULTI">;
}
/**
 * Static arrays of Ops grouped by category.
 * NOTE(review): the element types (e.g. `Binary: Ops<"SHL">[]`, `ALU: Ops<"SHL">[]`) are
 * presumably whatever was inferred from the array literals in the implementation. Do not
 * read them as "contains only SHL"; confirm against the implementation.
 */
export declare class GroupOp {
    static Unary: Ops<"RECIP">[];
    static Binary: Ops<"SHL">[];
    static Ternary: Ops<"WHERE">[];
    static ALU: Ops<"SHL">[];
    static Irreducible: Ops<"DEFINE_VAR">[];
    static Movement: Ops<"RESHAPE">[];
    static Buffer: Ops<"STORE">[];
    static Block: Ops<"BLOCK">[];
    static Commutative: Ops<"MAX">[];
    static Associative: Ops<"MUL">[];
    static Idempotent: Ops<"MAX">[];
    static UnsafePad: Ops<"RECIP">[];
}
export declare const view_supported_devices: string[];
export declare const identity_element: (op: Ops, dt: DType) => ConstType | ConstType[];
// NOTE(review): parameter name `visisted` is a typo for `visited`. Fix it in the
// implementation; parameter names are not part of the type, so callers are unaffected.
export declare const can_pad: (u: UOp, edges: Map, visisted: Set) => boolean;
export declare const END_FOR_UOP: Map, Ops<"STORE">[]>;
export declare const resolve: (x: ConstType, def?: boolean) => boolean;
export declare const smax: (...lst: sint[]) => sint;
export declare const smin: (...lst: sint[]) => sint;
export declare const ssimplify: (uop: UOp) => UOp;
export declare const sym_infer: (uop: sint, varVals: Map) => number;
/** Shape of the fields of a UOp (op, dtype, src, arg); not exported. */
type UOpInput = {
    op: Ops;
    dtype?: DType;
    src?: UOp[];
    arg?: any;
};
/**
 * Graph node: an op code (`op`), a `dtype`, source nodes (`src`) and an optional `arg`.
 * It carries the MathTrait arithmetic surface, presumably `MathTrait<UOp>` before the
 * type arguments were stripped. It also has shape, device, buffer, variable and
 * min/max-bound helpers.
 */
export declare class UOp extends MathTrait {
    op: Ops;
    dtype: DType;
    src: UOp[];
    arg?: any | undefined;
    static cache: WeakValueMap;
    _id: bigint;
    children: WeakValueMap;
    _buf?: Buffer;
    _metadata?: Metadata;
    constructor(op: Ops, dtype?: DType, src?: UOp[], arg?: any | undefined, _buffer?: Buffer);
    toString(): string;
    replace: (args: Partial) => UOp;
    get toposort(): Set;
    private _tuplize;
    get tuplize(): any[];
    private _st;
    get st(): ShapeTracker | undefined;
    private _full_shape;
    get full_shape(): sint[];
    get shape(): sint[];
    get size(): number;
    simplify: () => UOp;
    ssimplify: () => UOp;
    // NOTE(review): this member's type-parameter list was stripped up to the first `>`
    // (presumably `<T extends new (...) => ...>`), leaving `void>(`. Confirm against the implementation.
    _eval: void>(dtypes: DType[], expectedType: T) => InstanceType;
    // NOTE(review): the next three return the boxed `Boolean` / `Number` wrapper types.
    // The implementation should annotate them as `boolean` / `number`.
    bool: () => Boolean;
    int: () => Number;
    float: () => Number;
    substitute: (dvars: Map) => UOp;
    get st_arg(): ShapeTracker;
    get const_arg(): ConstType;
    get axis_arg(): number[];
    static sink: (...srcs: UOp[]) => UOp;
    detach: () => UOp;
    index: (idx: UOp, valid?: UOp) => UOp;
    const_like: (b: ConstLike) => UOp;
    broadcast: (count: number) => UOp;
    cast: (dtype: DType) => UOp;
    bitcast: (dtype: DType) => UOp;
    gep: (i: number[] | number) => UOp;
    load: (src: UOp[], kwargs?: Partial) => UOp;
    static load: (src: UOp[], kwargs?: Partial) => UOp;
    store: (src: UOp[], kwargs?: Partial) => UOp;
    static store: (src: UOp[], kwargs?: Partial) => UOp;
    alu: (arg: Ops, ...src: UOp[]) => UOp;
    static const: (dtype: DType, b: ConstLike) => UOp;
    static int: (b: number) => UOp;
    static bool: (b: boolean) => UOp;
    static float: (b: number) => UOp;
    valid: (st: ShapeTracker) => UOp;
    static range: (dtype: DType, start: sint, end: sint, idx: number) => UOp;
    _reduce_op: (op: Ops, axis: number[]) => UOp;
    r: (op: Ops, axis: number[]) => UOp;
    assign: (x: UOp) => UOp;
    contiguous: () => UOp;
    contiguous_backward: () => UOp;
    static multi: (more: UOp[], axis?: number, real?: boolean[]) => UOp;
    get bounds(): Generator<[sint, sint], void, unknown>;
    get axis(): any;
    get real(): any;
    get real_lbs(): UOp[];
    shard: (devices: string[], axis?: number) => UOp;
    static metaop: (op: Ops, shape: sint[], dtype: DType, device: string, arg?: any) => UOp;
    copy_to_device: (device: string | string[], clone?: boolean) => UOp;
    clone: () => UOp;
    get metadata(): Metadata | undefined;
    get base(): UOp;
    view: (new_st: ShapeTracker) => UOp;
    _mop: (op: Ops, arg: any) => UOp;
    reshape: (arg: sint[]) => UOp;
    pad: (arg: [sint, sint][]) => UOp;
    expand: (arg: sint[]) => UOp;
    permute: (arg: sint[]) => UOp;
    shrink: (arg: [sint, sint][]) => UOp;
    stride: (arg: sint[]) => UOp;
    static buffer_num: Generator;
    static new_buffer: (device: string, size: number, dtype: DType) => UOp;
    get device(): string | string[];
    private _device;
    get buf_uop(): UOp;
    buf_uop_view: () => UOp;
    static variable: (name: string, minVal?: ConstType, maxVal?: ConstType, dtype?: DType) => UOp;
    get expr(): any;
    bind: (val: number) => UOp;
    unbind: () => [Variable, number];
    val: () => number;
    vars: () => UOp[];
    variables: () => Variable[];
    /** Largest known int that divides this UOp. */
    constFactor: () => number;
    divides: (v: number) => UOp | undefined;
    get vmin(): number | bigint;
    get vmax(): number | bigint;
    private _min_max;
    private get min_max();
    private _sym_fxn;
    sym_infer: (varVals: Map) => number;
    render: (simplify?: boolean) => string;
}
/** Per-kernel settings: local/upcasted dimension counts and a `dont_use_locals` flag. */
export declare class KernelInfo {
    local_dims: number;
    upcasted: number;
    dont_use_locals: boolean;
    _id: bigint;
    static cache: WeakValueMap;
    constructor(local_dims?: number, // number of local dimensions (this is remapping RANGE to SPECIAL)
    upcasted?: number, // count that are upcasted (this is remapping RANGE to UNROLL)
    dont_use_locals?: boolean);
    toString: () => string;
}
export declare const python_alu: Map, (...x: ConstType[]) => ConstType>;
export declare const exec_alu: (op: Ops, dtype: DType, operands: ConstType[], truncateOutput?: boolean) => any;
export declare const print_uops: (uops: UOp[]) => void;
/** Named-field form of the UPat constructor arguments (same names as the positional parameters). */
export type UPatInput = {
    op?: Ops | Ops[];
    dtype?: DType | DType[];
    src?: UPat | UPat[] | [UPat[]];
    arg?: any;
    name?: string;
    allow_any_len?: boolean;
    location?: any;
    custom_early_reject?: Ops[];
};
/**
 * Rewrite callback: receives a record of bindings intersected with `{ ctx }` and returns
 * `Res`. NOTE(review): the `Ctx` / `Res` type parameters and the `Record` arguments are
 * missing in this copy.
 */
export type UPatFn = (args: Record & {
    ctx: Ctx;
}) => Res;
/** A rule: a pattern paired with its rewrite callback. */
export type Pattern = [UPat, UPatFn];
/**
 * Pattern node matched against UOps.
 * - Constraints: optional `op` / `dtype` lists, nested `src` patterns, an `arg`, and a binding `name`.
 * - `match`: returns an array of Maps (presumably one name-to-UOp binding set per way the pattern matched; confirm).
 * - `fn`: pairs this pattern with a callback to form a `Pattern`.
 */
export declare class UPat extends MathTrait {
    arg?: any | undefined;
    name?: string | undefined;
    custom_early_reject?: Ops[] | undefined;
    op?: Ops[];
    dtype?: DType[];
    _in_src?: UPat | UPat[] | [UPat[]];
    src?: UPat[][];
    allowed_len: number;
    location: [string, number];
    early_reject: Ops[];
    fn: (fn: UPatFn) => Pattern;
    constructor(op?: Ops | Ops[], dtype?: DType | DType[], src?: UPat | UPat[] | [UPat[]], arg?: any | undefined, name?: string | undefined, allow_any_len?: boolean, location?: any, custom_early_reject?: Ops[] | undefined);
    named: (name?: string) => this;
    static any: (src: UPat[]) => UPatAny;
    static var: (name?: string, dtype?: DType | DType[]) => UPat;
    static cvar: (name?: string, dtype?: DType, vec?: boolean) => UPat;
    static const: (dtype?: DType | DType[], b?: ConstLike) => UPat;
    index: (idx: UPat, valid?: UPat) => UPat;
    static index: (self: UPat, idx: UPat, valid?: UPat) => UPat;
    view: (st?: ShapeTracker, kwargs?: Partial) => UPat;
    cast: (dtype?: DType) => UPat;
    bitcast: (dtype?: DType) => UPat;
    gep: (i: number) => UPat;
    load: (src?: UPat[], kwargs?: Partial) => UPat;
    static load: (src: UPat[], kwargs?: Partial) => UPat;
    store: (src: UPat[], kwargs?: Partial) => UPat;
    static store: (src: UPat[], kwargs?: Partial) => UPat;
    assign: (x: UPat) => UPat;
    const_like: (b: ConstLike) => UPat;
    alu: (op: Ops, ...src: UPat[]) => UPat;
    toString: () => string;
    match: (uop: UOp, store: Map) => Map[];
}
/** UPat subclass returned by `UPat.any`; overrides `match`. */
export declare class UPatAny extends UPat {
    match: (uop: UOp, store: Map) => Map[];
}
/**
 * Holds a list of `[UPat, UPatFn]` rules plus a `pdict` lookup built from them.
 * - `add`: takes another PatternMatcher and returns a PatternMatcher.
 * - `rewrite`: returns `Res | undefined` (presumably `undefined` when no rule applies; confirm).
 */
export declare class PatternMatcher {
    patterns: [UPat, UPatFn][];
    pdict: Map, [UPat, UPatFn, Set>][]>;
    constructor(patterns: [UPat, UPatFn][]);
    add: (more: PatternMatcher) => PatternMatcher;
    rewrite: (uop: UOp, ctx?: any) => Res | undefined;
}
/** Record of one graph rewrite: source location, the sink UOp, and `[UOp, UOp, UPat]` match triples. */
export declare class TrackedGraphRewrite {
    loc: [string, number];
    sink: UOp;
    matches: [UOp, UOp, UPat][];
}
/** PatternMatcher subclass that overrides `rewrite`. */
export declare class TrackedPatternMatcher extends PatternMatcher {
    rewrite: (uop: UOp, ctx?: Ctx) => UOp | undefined;
}
/** Declared to return `never`, i.e. it does not return normally. */
export declare const launch_viz: (env_str: string, data: string) => never;
/** Holds a PatternMatcher, optional ctx and a `replace` Map; exposes top-down and bottom-up rewrite entry points. */
export declare class RewriteContext {
    pm: PatternMatcher;
    ctx?: Ctx;
    replace: Map;
    constructor(pm: PatternMatcher, ctx?: Ctx);
    top_down_rewrite: (n: UOp) => UOp;
    bottom_up_rewrite: (n: UOp) => UOp;
}
// NOTE(review): `bottom_up` presumably selects RewriteContext.bottom_up_rewrite over
// top_down_rewrite; confirm against the implementation.
export declare const graph_rewrite: (sink: UOp, pm: PatternMatcher, ctx?: Ctx, bottom_up?: boolean) => UOp;
export declare const graph_rewrite_map: (sink: UOp, pm: PatternMatcher, ctx?: Ctx, bottom_up?: boolean) => Map;
export declare const spec: PatternMatcher;
export declare const type_verify: (uops: UOp[], extra_specs?: PatternMatcher[]) => void;
export declare function split_uop(x: UOp, sep: Ops): Generator;
export declare const div_and_mod_folding: (x: UOp, y: UOp, which: typeof Ops.MOD | typeof Ops.IDIV, split_rem?: boolean) => undefined | UOp;
export declare const canonicalize_simplex: (X: UOp) => UOp | undefined;
export declare const is_increasing: (f: UOp) => boolean;
export declare const parse_valid: (valid: UOp) => [UOp, boolean, number];
export declare const uop_given_valid: (valid: UOp, uop: UOp) => UOp | undefined;
export declare const simplify_valid: (valid: UOp) => UOp | undefined;
export declare const sint_to_uop: (x: sint, dtype?: DType) => UOp;
export declare const symbolic_simple: PatternMatcher;
export declare const symbolic: PatternMatcher;
export declare const symbolic_flat: PatternMatcher;
export declare const _substitute: PatternMatcher, UOp | undefined>;
export declare const renderer: PatternMatcher;
/** A size/index value: either a plain number or a UOp. */
export type sint = number | UOp;
export declare const merge_views: PatternMatcher;
export declare const view_left: PatternMatcher;
export declare const TRANSCENDENTAL_SUPPORTED_DTYPES: DType[];
/** Replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio. */
export declare const _lazy_map_numbers: (x: UOp, inf: UOp, _inf: UOp, nan: UOp, ratio: UOp) => UOp;
export declare const mantissa_bits: (d: DType) => number;
export declare const exponent_bias: (d: DType) => number;
export declare const exponent_mask: (d: DType) => number;
export declare const shr: (x: UOp, y: number) => UOp;
export declare const shl: (x: UOp, y: number) => UOp;
/** Round d:float to int away from 0. */
export declare const rintk: (d: UOp) => UOp;
/** cast(2^q, float_dtype) where q is any integer in the range [-126, 127]. */
export declare const pow2if: (q: UOp, float_dtype: DType) => UOp;
/** Calculate the integer part of log2(d), where d is a normalized fp value in the range [0, +inf). */
export declare const ilogb2k: (d: UOp) => UOp;
/** d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number. */
export declare const ldexp3k: (d: UOp, e: UOp) => UOp;
/** d*2^e. Much faster than ldexp3k but risky: requires d > 0 and d not denormal. */
export declare const ldexp2k: (d: UOp, e: UOp) => UOp;
/** frexp(v) -> (mantissa, exponent), assuming v != 0. */
export declare const frexp: (v: UOp) => [UOp, UOp];
/**
 * Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where 39800.0 <= d <= +Inf
 * Returns a tuple of `(r, q)`:
 * - `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(d % (pi/2))`.
 * - `q`[int32] is an integer, and q % 4 corresponds to the quadrant of the original angle `d`.
 */
export declare const payne_hanek_reduction: (d: UOp) => [UOp, UOp];
/**
 * Performs Cody-Waite Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where 0 <= abs(d) <= 39800.0
 * Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
 */
export declare const cody_waite_reduction: (d: UOp) => [UOp, UOp];
export declare const trig_poly: (d: UOp, coeff32: number[], coeff64: number[]) => UOp;
export declare const sin_poly: (d: UOp) => UOp;
export declare const sin_poly_small: (d: UOp, q: UOp) => UOp;
export declare const sin_poly_large: (d: UOp, q: UOp) => UOp;
/**
 * Implements a 1.0 ULP approximation for Ops.SIN.
 * - `fast: true` assumes x <= switch_over (x presumably being the input `d`).
 * - switch_over is the threshold for switching to payne_hanek_reduction.
 */
export declare const xsin: ({ d, fast, switch_over }: {
    d: UOp;
    fast?: boolean;
    switch_over?: number;
}) => UOp;
/**
 * Implements a 1.0 ULP approximation for Ops.EXP2
 * Paper: https://arxiv.org/pdf/2001.09258
 */
export declare const xexp2: ({ d }: {
    d: UOp;
}) => UOp;
/**
 * Implements a 1.0 ULP approximation for Ops.LOG2
 * Paper: https://arxiv.org/pdf/2001.09258 5.5
 */
export declare const xlog2: ({ d }: {
    d: UOp;
}) => UOp;
export declare const fold_expanded: (ex: UOp, buf: UOp) => UOp | undefined;
export declare const fix_unfoldable_image_load: (load: UOp, buf: UOp) => UOp | undefined;
export declare const buf_idx_pat: UPat;
export declare const float4_folding: PatternMatcher;
export declare const simplify_valid_load: (buf: UOp, start_idx: UOp, valid: UOp) => undefined | UOp;
export declare const get_late_rewrite_patterns: (ops: Ops[], force_transcendental?: any) => PatternMatcher;
export declare const threefry2x32: (x: UOp, key: UOp) => UOp;
export declare const sigmoid_like: (x: UOp, y: UOp) => UOp;
export declare const loop_collapse: (compval: UOp, multconst: UOp, rng: UOp, acc: UOp, idx2?: UOp, idx3?: UOp, extra?: UOp, vec?: UOp, ne?: UOp, add?: UOp, mul?: UOp) => UOp | undefined;
export declare const index_collapse: (idx: UOp, rng: UOp, buf: UOp, ld: UOp, acc: UOp, add?: UOp, mul?: UOp) => UOp | undefined;
export declare const gep_through_wmma: (gep: UOp, wmma: UOp) => UOp | undefined;
export declare const no_vectorized_wmma: (wmma: UOp) => UOp | undefined;
export declare const reduce_collapse: (acc: UOp, ret: UOp, alu: UOp) => UOp | undefined;
export declare const acc_pat: UPat, rng_pat: UPat;
export declare const rng_aug: UPatAny;
export declare const index_load: UPat;
export declare const arange_augrng: UPatAny;
export declare const arange_m: UPat;
export declare const mulacc_unrolled: PatternMatcher;
export declare const sym: PatternMatcher;
export declare const _expand_arg_to_idx: (args: [number, number][], rpk: Map) => number;
export declare const _choices_from_args: (args: [number, number][]) => Map[];
export declare const _swizzle_args: (cargs: [number, number][], eargs: [number, number][], exclude_args: number[]) => number[];
export declare const do_expand: (root: UOp) => UOp | undefined;
export declare const do_contract: (con: UOp) => UOp;
export declare const no_vectorized_alu: (alu: UOp) => UOp | undefined;
export declare const create_gate: (root: UOp) => undefined | UOp;
export declare const expander: PatternMatcher;
export declare const no_vectorized_load_store: (ls: UOp) => UOp | undefined;
export declare const no_vectorized_acc: (acc: UOp) => UOp | undefined;
export declare const devectorize: PatternMatcher;
export declare const delete_redundant_gates: (buf: UOp, idx: UOp, val: UOp, store_gate: UOp, cast?: UOp) => undefined | UOp;
export declare const load_store_indexing: PatternMatcher;
export declare const migrate_indexing: PatternMatcher;
export declare const move_mask: (x: UOp, buf: UOp, idx: UOp, mask: UOp, cast?: UOp) => UOp;
export declare const pm_render: PatternMatcher;
export declare const full_graph_rewrite: (sink: UOp, opts?: Renderer) => UOp;
export {};