/*
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

import { alignTo, assert } from "../utils";
import { StorageBuffer, type StorageBufferDescriptor } from "./buffer";
import { wasm, ndarrayf, type WasmPtr } from "../wasm";

/** Element types an ndarray can hold. */
export type DType = "i8" | "u8" | "i16" | "u16" | "i32" | "u32" | "f32" | "f64";

/** Where an ndarray's element storage lives. */
export type NdarrayResidency = "cpu-webassembly" | "gpu-storagebuffer";

/** Union of the typed-array classes used to view ndarray elements. */
export type NumberTypedArray =
  | Int8Array
  | Uint8Array
  | Int16Array
  | Uint16Array
  | Int32Array
  | Uint32Array
  | Float32Array
  | Float64Array;

/**
 * Construct signature shared by every typed array in {@link NumberTypedArray}:
 * a view over `buffer` starting at `byteOffset` with `length` elements.
 */
export interface NumberTypedArrayConstructor<T = NumberTypedArray> {
  readonly BYTES_PER_ELEMENT: number;
  new (buffer: ArrayBufferLike, byteOffset?: number, length?: number): T;
}

/** Static per-dtype metadata. */
export type DTypeInfo = {
  readonly dtype: DType;
  /** Typed-array class used to view elements of this dtype. */
  readonly ctor: NumberTypedArrayConstructor;
  /** Size of one element in bytes. */
  readonly bytesPerElement: number;
  /**
   * Matching WGSL scalar type, or `null` when WGSL has no native scalar for this dtype.
   * WGSL defines no 64-bit float, so `"f64"` is never produced; it stays in the union only
   * so existing code that names it keeps compiling.
   */
  readonly wgslScalarType: "i32" | "u32" | "f32" | "f64" | null;
};

const DTYPE_TABLE: Record<DType, DTypeInfo> = {
  i8: { dtype: "i8", ctor: Int8Array, bytesPerElement: 1, wgslScalarType: null },
  u8: { dtype: "u8", ctor: Uint8Array, bytesPerElement: 1, wgslScalarType: null },
  i16: { dtype: "i16", ctor: Int16Array, bytesPerElement: 2, wgslScalarType: null },
  u16: { dtype: "u16", ctor: Uint16Array, bytesPerElement: 2, wgslScalarType: null },
  i32: { dtype: "i32", ctor: Int32Array, bytesPerElement: 4, wgslScalarType: "i32" },
  u32: { dtype: "u32", ctor: Uint32Array, bytesPerElement: 4, wgslScalarType: "u32" },
  f32: { dtype: "f32", ctor: Float32Array, bytesPerElement: 4, wgslScalarType: "f32" },
  // WGSL has no f64 scalar type, so f64 data cannot be declared natively in a shader.
  f64: { dtype: "f64", ctor: Float64Array, bytesPerElement: 8, wgslScalarType: null },
};

/**
 * Looks up the static metadata for `dtype`.
 *
 * @throws Error("Unknown dtype: …") when `dtype` is not one of the supported names
 *   (reachable from untyped callers).
 */
export const dtypeInfo = (dtype: DType): DTypeInfo => {
  // Own-property check: a plain `DTYPE_TABLE[dtype]` would resolve inherited
  // Object.prototype members (e.g. "toString") and return a function instead of throwing.
  if (!Object.prototype.hasOwnProperty.call(DTYPE_TABLE, dtype)) {
    throw new Error(`Unknown dtype: ${String(dtype)}`);
  }
  return DTYPE_TABLE[dtype];
};

/**
 * Caller-supplied layout. Strides and offset are in bytes; omitted strides default to
 * packed row-major and an omitted offset defaults to 0.
 */
export type NdLayoutDescriptor = {
  shape: ReadonlyArray<number>;
  stridesBytes?: ReadonlyArray<number>;
  offsetBytes?: number;
};
/** Largest signed 32-bit integer; strides are mirrored into wasm memory as i32. */
const I32_MAX = 0x7fffffff;

/** Smallest signed 32-bit integer. */
const I32_MIN = -0x80000000;

/** Largest unsigned 32-bit integer; dimensions are mirrored into wasm memory as u32. */
const U32_MAX = 0xffffffff;

/**
 * Returns wasm linear memory as an ArrayBuffer. Re-fetched on every call because the
 * buffer object is replaced when the memory grows.
 */
const wasmMemoryBuffer = (): ArrayBuffer => wasm.memory().buffer as unknown as ArrayBuffer;

/**
 * Validates a shape and returns a mutable copy.
 * Every dimension must be a non-negative integer that fits in u32, because dimensions are
 * written to wasm memory as u32 (see CPUndarray.empty).
 */
const validateShape = (shape: ReadonlyArray<number>): number[] => {
  assert(Array.isArray(shape), "shape must be an array of dimension sizes");
  const out: number[] = new Array(shape.length);
  for (let i = 0; i < shape.length; i++) {
    const d = shape[i] as number;
    assert(Number.isInteger(d) && d >= 0, `shape[${i}] must be an integer >= 0 (got ${d})`);
    // Larger values would be silently truncated by the `>>> 0` used when mirroring to wasm.
    assert(d <= U32_MAX, `shape[${i}] must fit in u32 (got ${d})`);
    out[i] = d;
  }
  return out;
};

/**
 * Computes packed row-major (C-order) byte strides: the last axis has stride
 * `bytesPerElement` and each earlier axis spans the product of all later dimensions.
 *
 * @throws if any stride exceeds the i32 range or the running product becomes unsafe.
 */
const defaultRowMajorStridesBytes = (shape: ReadonlyArray<number>, bytesPerElement: number): number[] => {
  const strides: number[] = new Array(shape.length);
  let stride = bytesPerElement;
  for (let i = shape.length - 1; i >= 0; i--) {
    assert(Number.isInteger(stride) && stride >= 0, "Stride overflow while computing row-major strides");
    assert(stride <= I32_MAX, `row-major stride exceeds i32 range (got ${stride})`);
    strides[i] = stride;
    stride *= shape[i]!;
    assert(Number.isFinite(stride) && stride >= 0, "Stride overflow while computing row-major strides");
    assert(stride <= Number.MAX_SAFE_INTEGER, "Stride overflow while computing row-major strides");
  }
  return strides;
};

/**
 * Validates caller-provided byte strides and returns a mutable copy.
 * Each stride may be negative, but it must be an integer that fits in i32 and is a
 * multiple of the element size.
 */
const validateStridesBytes = (
  stridesBytes: ReadonlyArray<number>,
  ndim: number,
  bytesPerElement: number,
): number[] => {
  assert(Array.isArray(stridesBytes), "stridesBytes must be an array");
  assert(
    stridesBytes.length === ndim,
    `stridesBytes length (${stridesBytes.length}) must equal shape length (${ndim})`,
  );
  const out: number[] = new Array(ndim);
  for (let i = 0; i < ndim; i++) {
    const s = stridesBytes[i] as number;
    assert(Number.isInteger(s), `stridesBytes[${i}] must be an integer (got ${s})`);
    assert(s >= I32_MIN && s <= I32_MAX, `stridesBytes[${i}] must fit in i32 (got ${s})`);
    assert(
      s % bytesPerElement === 0,
      `stridesBytes[${i}] (${s}) must be a multiple of bytesPerElement (${bytesPerElement})`,
    );
    out[i] = s;
  }
  return out;
};

/** Validates the byte offset of element (0, …, 0); `undefined` means 0. */
const validateOffsetBytes = (offsetBytes: number | undefined, bytesPerElement: number): number => {
  const off = offsetBytes ?? 0;
  assert(Number.isInteger(off) && off >= 0, `offsetBytes must be an integer >= 0 (got ${off})`);
  assert(
    off % bytesPerElement === 0,
    `offsetBytes (${off}) must be a multiple of bytesPerElement (${bytesPerElement})`,
  );
  return off;
};

/** Number of elements described by `shape` (1 for a 0-D scalar, 0 if any axis is 0). */
const numelFromShape = (shape: ReadonlyArray<number>): number => {
  // Check for an empty axis first so that large earlier axes cannot trip the overflow
  // assertion for an array that is actually empty.
  if (shape.includes(0)) return 0;
  let n = 1;
  for (const d of shape) {
    n *= d;
    assert(Number.isFinite(n), "numel overflow");
  }
  return n;
};

/**
 * Returns the number of backing bytes, counted from the start of the backing region, that
 * every element addressed by the layout needs in order to fit.
 *
 * - 0-D (scalar): `offsetBytes + bytesPerElement`.
 * - Empty array (some axis is 0): 0.
 * - Otherwise: the furthest element's byte offset plus one element. This is computed with
 *   BigInt so that large shapes or strides cannot lose precision.
 *
 * @throws if negative strides reach below byte 0, or the result exceeds the u32 range.
 */
const requiredBackingBytes = (
  shape: ReadonlyArray<number>,
  stridesBytes: ReadonlyArray<number>,
  offsetBytes: number,
  bytesPerElement: number,
): number => {
  if (shape.length === 0) {
    const req = offsetBytes + bytesPerElement;
    assert(req <= U32_MAX, `required backing bytes exceeds wasm32 address space (got ${req})`);
    return req;
  }
  if (numelFromShape(shape) === 0) return 0;

  let min = BigInt(offsetBytes);
  let max = BigInt(offsetBytes);
  for (let i = 0; i < shape.length; i++) {
    // Displacement produced by the last index along axis i; negative strides pull the
    // lowest reachable byte down, positive strides push the highest one up.
    const extent = BigInt(shape[i]! - 1) * BigInt(stridesBytes[i]!);
    if (extent < 0n) min += extent;
    else max += extent;
  }
  assert(
    min >= 0n,
    `layout underflows: minimum byte offset is ${min} (offsetBytes is too small for negative strides)`,
  );
  const req = max + BigInt(bytesPerElement);
  assert(req <= BigInt(U32_MAX), `required backing bytes exceeds wasm32 address space (got ${req})`);
  return Number(req);
};

/**
 * True when the layout is packed row-major starting at byte 0. 0-D and empty arrays
 * always qualify.
 *
 * This never throws. The expected strides are compared incrementally instead of via
 * `defaultRowMajorStridesBytes`, whose range assertions would fire for large
 * broadcast-style (zero-stride) shapes. Once the expected stride leaves the i32 range,
 * it can no longer match a validated stride, so `false` is the correct answer.
 */
const isContiguousRowMajor = (
  shape: ReadonlyArray<number>,
  stridesBytes: ReadonlyArray<number>,
  offsetBytes: number,
  bytesPerElement: number,
): boolean => {
  if (offsetBytes !== 0) return false;
  if (shape.length === 0) return true;
  if (numelFromShape(shape) === 0) return true;
  let expected = bytesPerElement;
  for (let i = shape.length - 1; i >= 0; i--) {
    if (stridesBytes[i] !== expected) return false;
    expected *= shape[i]!;
  }
  return true;
};

/**
 * Validates `layout` for `dtype` and fills in defaults (packed row-major strides and a
 * zero offset). Returns the validated layout together with the backing size it requires.
 * This is shared by every factory so all residencies apply identical rules.
 */
const resolveLayout = (
  dtype: DType,
  layout: NdLayoutDescriptor,
): { shape: number[]; stridesBytes: number[]; offsetBytes: number; byteLength: number } => {
  const { bytesPerElement } = dtypeInfo(dtype);
  const shape = validateShape(layout.shape);
  const offsetBytes = validateOffsetBytes(layout.offsetBytes, bytesPerElement);
  const stridesBytes = layout.stridesBytes
    ? validateStridesBytes(layout.stridesBytes, shape.length, bytesPerElement)
    : defaultRowMajorStridesBytes(shape, bytesPerElement);
  const byteLength = requiredBackingBytes(shape, stridesBytes, offsetBytes, bytesPerElement);
  return { shape, stridesBytes, offsetBytes, byteLength };
};

/**
 * Residency-independent description of a strided n-dimensional array.
 *
 * The layout is expressed in bytes relative to a backing region of `byteLength` bytes.
 * `requiredBackingBytes` sizes that region assuming element (i0, …, ik) lives at
 * `offsetBytes + Σ i_j * stridesBytes[j]`.
 */
export abstract class Ndarray {
  readonly dtype: DType;
  readonly shape: number[];
  readonly stridesBytes: number[];
  readonly offsetBytes: number;
  readonly bytesPerElement: number;
  readonly numel: number;
  readonly byteLength: number;

  protected constructor(
    dtype: DType,
    shape: number[],
    stridesBytes: number[],
    offsetBytes: number,
    byteLength: number,
  ) {
    this.dtype = dtype;
    this.shape = shape;
    this.stridesBytes = stridesBytes;
    this.offsetBytes = offsetBytes;
    this.bytesPerElement = dtypeInfo(dtype).bytesPerElement;
    this.numel = numelFromShape(shape);
    this.byteLength = byteLength;
  }

  /** Number of axes (0 for a scalar). */
  get ndim(): number {
    return this.shape.length;
  }

  /** WGSL scalar type for this dtype, or `null` if WGSL has none. */
  get wgslScalarType(): DTypeInfo["wgslScalarType"] {
    return dtypeInfo(this.dtype).wgslScalarType;
  }

  /** True when the elements are packed in row-major order starting at byte 0. */
  get isContiguousC(): boolean {
    return isContiguousRowMajor(this.shape, this.stridesBytes, this.offsetBytes, this.bytesPerElement);
  }

  /** Returns a defensive copy of the layout, suitable for describing another array. */
  layout(): { shape: number[]; stridesBytes: number[]; offsetBytes: number } {
    return { shape: this.shape.slice(), stridesBytes: this.stridesBytes.slice(), offsetBytes: this.offsetBytes };
  }

  abstract get residency(): NdarrayResidency;
}

/** Sentinel returned by `ndarrayf.offsetBytes` for out-of-bounds indices or offset overflow. */
const ND_ERROR = 0xFFFF_FFFF;

/**
 * Ndarray whose backing region lives in WebAssembly linear memory.
 *
 * The shape and strides are mirrored into wasm memory (as u32 and i32 respectively), so
 * that `ndarrayf.offsetBytes` can resolve element addresses.
 * NOTE(review): no API visible here frees the wasm allocations; confirm how their
 * lifetime is managed.
 */
export class CPUndarray extends Ndarray {
  /** Wasm address of the start of the backing region (0 when `byteLength` is 0). */
  readonly basePtrBytes: WasmPtr;
  /** Wasm address of the shape, stored as `ndim` u32 values. */
  readonly shapePtr: WasmPtr;
  /** Wasm address of the byte strides, stored as `ndim` i32 values. */
  readonly stridesPtr: WasmPtr;
  /** Wasm scratch space (`ndim` u32 values) used to pass indices to `ndarrayf.offsetBytes`. */
  private readonly indicesPtr: WasmPtr;
  /** Memory buffer that `wholeMemoryView` was created over; used to detect memory growth. */
  private viewBuffer: ArrayBuffer | null = null;
  /** Typed-array view of this dtype spanning all of wasm memory; used by get()/set(). */
  private wholeMemoryView: NumberTypedArray | null = null;

  private constructor(
    dtype: DType,
    shape: number[],
    stridesBytes: number[],
    offsetBytes: number,
    byteLength: number,
    basePtrBytes: WasmPtr,
    shapePtr: WasmPtr,
    stridesPtr: WasmPtr,
    indicesPtr: WasmPtr,
  ) {
    super(dtype, shape, stridesBytes, offsetBytes, byteLength);
    this.basePtrBytes = basePtrBytes;
    this.shapePtr = shapePtr;
    this.stridesPtr = stridesPtr;
    this.indicesPtr = indicesPtr;
  }

  /**
   * Allocates an uninitialised array with the given layout in wasm memory.
   *
   * @throws if the layout is invalid or needs more than the u32 address range.
   */
  static empty(dtype: DType, layout: NdLayoutDescriptor): CPUndarray {
    // Called before anything else, presumably so that an uninitialised wasm module fails
    // early. TODO: confirm against ../wasm.
    wasm.memory();
    const { shape, stridesBytes, offsetBytes, byteLength } = resolveLayout(dtype, layout);

    const ndim = shape.length >>> 0;
    const shapePtr = wasm.allocU32(ndim);
    const stridesPtr = wasm.allocU32(ndim);
    const indicesPtr = wasm.allocU32(ndim);

    // The views are created after those allocations and filled before allocBytes below,
    // so a memory grow triggered by allocation cannot detach them mid-write.
    const shapeView = wasm.u32view(shapePtr, ndim);
    for (let i = 0; i < shape.length; i++) shapeView[i] = shape[i]! >>> 0;
    const strideView = wasm.i32view(stridesPtr, ndim);
    for (let i = 0; i < stridesBytes.length; i++) strideView[i] = stridesBytes[i]! | 0;

    const basePtrBytes = byteLength > 0 ? wasm.allocBytes(byteLength >>> 0) : 0;
    return new CPUndarray(
      dtype,
      shape,
      stridesBytes,
      offsetBytes,
      byteLength,
      basePtrBytes,
      shapePtr,
      stridesPtr,
      indicesPtr,
    );
  }

  /** Like `empty`, but the whole backing region is zero-filled. */
  static zeros(dtype: DType, layout: NdLayoutDescriptor): CPUndarray {
    const a = CPUndarray.empty(dtype, layout);
    a.zero_();
    return a;
  }

  /**
   * Creates a packed row-major array and copies the first `numel` values from `src`.
   *
   * @throws if `src` holds fewer than `numel` values.
   */
  static fromArray<T extends ArrayLike<number>>(dtype: DType, shape: ReadonlyArray<number>, src: T): CPUndarray {
    const dst = CPUndarray.empty(dtype, { shape });
    assert(dst.isContiguousC, "CPUndarray.fromArray currently requires a contiguous row-major layout");
    assert(src.length >= dst.numel, `source length (${src.length}) must be >= numel (${dst.numel})`);
    const data = dst.data();
    for (let i = 0; i < dst.numel; i++) data[i] = src[i] as number;
    return dst;
  }

  override get residency(): NdarrayResidency {
    return "cpu-webassembly";
  }

  /** Returns a typed-array view over all of wasm memory, recreating it after memory growth. */
  private ensureAllView(): NumberTypedArray {
    const buf = wasmMemoryBuffer();
    let view = this.wholeMemoryView;
    if (view === null || this.viewBuffer !== buf) {
      view = new (dtypeInfo(this.dtype).ctor)(buf) as NumberTypedArray;
      this.wholeMemoryView = view;
      this.viewBuffer = buf;
    }
    return view;
  }

  /**
   * Raw bytes of the backing region. The result is a live view into wasm memory, and it
   * becomes detached if the memory grows.
   */
  backingBytes(): Uint8Array {
    if (this.byteLength === 0) return new Uint8Array(wasmMemoryBuffer(), 0, 0);
    return wasm.u8view(this.basePtrBytes, this.byteLength >>> 0);
  }

  /**
   * Returns a typed-array view over the elements in row-major order. The result is a live
   * view into wasm memory, and it becomes detached if the memory grows.
   *
   * @throws unless the layout is contiguous row-major (see `isContiguousC`).
   */
  data(): NumberTypedArray {
    assert(
      this.isContiguousC,
      "CPUndarray.data() requires a contiguous row-major layout (use backingBytes() for raw backing storage)",
    );
    const ctor = dtypeInfo(this.dtype).ctor;
    const buf = wasmMemoryBuffer();
    if (this.numel === 0) return new ctor(buf, 0, 0) as NumberTypedArray;
    return new ctor(buf, (this.basePtrBytes + this.offsetBytes) >>> 0, this.numel >>> 0) as NumberTypedArray;
  }

  /**
   * Computes the byte offset of the element at `indices`, relative to `basePtrBytes`.
   *
   * @throws on a wrong index count, a non-integer or negative index, an index outside its
   *   axis, or an offset outside the backing region.
   */
  private offsetBytesAt(indices: ReadonlyArray<number>): number {
    assert(indices.length === this.ndim, `expected ${this.ndim} indices, got ${indices.length}`);
    if (this.ndim === 0) return this.offsetBytes;

    const idxView = wasm.u32view(this.indicesPtr, this.ndim >>> 0);
    for (let i = 0; i < this.ndim; i++) {
      const v = indices[i] as number;
      const dim = this.shape[i]!;
      assert(Number.isInteger(v) && v >= 0, `index[${i}] must be an integer >= 0 (got ${v})`);
      // Bounds-check in JS: `v >>> 0` below wraps values >= 2**32 (2**32 becomes 0), which
      // would otherwise turn an out-of-range index into a valid one before wasm sees it.
      assert(v < dim, `index[${i}] (${v}) is out of bounds for axis ${i} with size ${dim}`);
      idxView[i] = v >>> 0;
    }

    // Reinterpret the result as u32. If the binding returns a raw wasm i32, JS receives it
    // as a signed number, so ND_ERROR would arrive as -1 and offsets >= 2**31 as negatives.
    const off =
      ndarrayf.offsetBytes(this.shapePtr, this.stridesPtr, this.indicesPtr, this.ndim >>> 0, this.offsetBytes >>> 0) >>>
      0;
    assert(off !== ND_ERROR, "index out of bounds (or offset overflow)");
    assert(off + this.bytesPerElement <= this.byteLength, "computed byte offset is outside backing storage");
    return off;
  }

  /** Index of the element at `indices` within the whole-memory typed-array view. */
  private elementIndexAt(indices: ReadonlyArray<number>): number {
    const abs = (this.basePtrBytes + this.offsetBytesAt(indices)) >>> 0;
    assert(abs % this.bytesPerElement === 0, "internal error: misaligned element address");
    return abs / this.bytesPerElement;
  }

  /** Reads the element at `indices` (one index per axis; none for a scalar). */
  get(...indices: number[]): number {
    const index = this.elementIndexAt(indices);
    return this.ensureAllView()[index] as number;
  }

  /**
   * Writes `value` at `indices`. The value is converted according to the dtype's
   * typed-array semantics.
   */
  set(value: number, ...indices: number[]): void {
    const index = this.elementIndexAt(indices);
    this.ensureAllView()[index] = value;
  }

  /** Zero-fills the entire backing region, including any bytes not addressed by the layout. */
  zero_(): void {
    if (this.byteLength === 0) return;
    this.backingBytes().fill(0);
  }

  /**
   * Copies the backing region into a new, owned StorageBuffer and returns a GPUndarray
   * with the same layout. `baseOffsetBytes` is set to 0.
   */
  uploadToGPU(
    ctx: { device: GPUDevice; queue: GPUQueue },
    desc: Omit<StorageBufferDescriptor, "byteLength" | "data"> = {},
  ): GPUndarray {
    const bytes = this.backingBytes();
    const sb = new StorageBuffer(ctx.device, ctx.queue, {
      label: desc.label,
      byteLength: this.byteLength,
      data: bytes,
      copyDst: desc.copyDst,
      copySrc: (desc as { copySrc?: boolean }).copySrc,
      usage: desc.usage,
    });
    return new GPUndarray(
      this.dtype,
      this.shape.slice(),
      this.stridesBytes.slice(),
      this.offsetBytes,
      this.byteLength,
      sb,
      0,
      true,
    );
  }
}

/**
 * Ndarray whose backing region lives in a StorageBuffer, starting at `baseOffsetBytes`.
 *
 * Buffers created by `empty()` or `CPUndarray.uploadToGPU()` are owned and released by
 * `destroy()`. Buffers passed to `wrap()` are borrowed and left alone.
 */
export class GPUndarray extends Ndarray {
  readonly buffer: StorageBuffer;
  readonly baseOffsetBytes: number;
  private readonly owned: boolean;

  constructor(
    dtype: DType,
    shape: number[],
    stridesBytes: number[],
    offsetBytes: number,
    byteLength: number,
    buffer: StorageBuffer,
    baseOffsetBytes: number = 0,
    owned: boolean = false,
  ) {
    super(dtype, shape, stridesBytes, offsetBytes, byteLength);
    assert(
      Number.isInteger(baseOffsetBytes) && baseOffsetBytes >= 0,
      `baseOffsetBytes must be an integer >= 0 (got ${baseOffsetBytes})`,
    );
    assert(
      (baseOffsetBytes & 3) === 0,
      `baseOffsetBytes must be 4-byte aligned for storage buffers (got ${baseOffsetBytes})`,
    );
    this.buffer = buffer;
    this.baseOffsetBytes = baseOffsetBytes;
    this.owned = owned;
  }

  /** Allocates a new owned StorageBuffer sized for `layout`. Its contents are not initialised here. */
  static empty(
    ctx: { device: GPUDevice; queue: GPUQueue },
    dtype: DType,
    layout: NdLayoutDescriptor,
    desc: Omit<StorageBufferDescriptor, "byteLength" | "data"> = {},
  ): GPUndarray {
    const { shape, stridesBytes, offsetBytes, byteLength } = resolveLayout(dtype, layout);
    const sb = new StorageBuffer(ctx.device, ctx.queue, {
      label: desc.label,
      byteLength,
      copyDst: desc.copyDst,
      copySrc: (desc as { copySrc?: boolean }).copySrc,
      usage: desc.usage,
    });
    return new GPUndarray(dtype, shape, stridesBytes, offsetBytes, byteLength, sb, 0, true);
  }

  /**
   * Views an existing StorageBuffer with the given layout without taking ownership.
   * NOTE(review): this does not check that `baseOffsetBytes + byteLength` fits inside
   * `buffer`; confirm that callers guarantee it.
   */
  static wrap(
    buffer: StorageBuffer,
    dtype: DType,
    layout: NdLayoutDescriptor,
    baseOffsetBytes: number = 0,
  ): GPUndarray {
    const { shape, stridesBytes, offsetBytes, byteLength } = resolveLayout(dtype, layout);
    return new GPUndarray(dtype, shape, stridesBytes, offsetBytes, byteLength, buffer, baseOffsetBytes, false);
  }

  override get residency(): NdarrayResidency {
    return "gpu-storagebuffer";
  }

  /**
   * Returns the buffer range covering this array's backing region. The size is rounded up
   * to a multiple of 4 bytes.
   *
   * NOTE(review): the buffers created above request exactly `byteLength` bytes. Confirm that
   * StorageBuffer pads its allocation; otherwise binding an array whose byteLength is not a
   * multiple of 4 (e.g. 3 × i8) would exceed the buffer. WebGPU also requires
   * storage-binding offsets to be multiples of `minStorageBufferOffsetAlignment` (256 by
   * default), which is stricter than the constructor's 4-byte check.
   */
  bindingResource(): { buffer: StorageBuffer; offset: number; size: number } {
    return { buffer: this.buffer, offset: this.baseOffsetBytes, size: alignTo(this.byteLength, 4) };
  }

  /**
   * Copies the backing region into a new CPUndarray with the same layout.
   *
   * @throws unless the StorageBuffer was created with `copySrc: true`.
   */
  async readbackToCPU(): Promise<CPUndarray> {
    assert(
      this.buffer.canReadback,
      "GPUndarray.readbackToCPU() requires the underlying StorageBuffer to be created with copySrc: true",
    );
    const bytes = await this.buffer.read(this.baseOffsetBytes, this.byteLength);
    // Allocate only after the await, so no stale wasm view is held across it.
    const cpu = CPUndarray.empty(this.dtype, {
      shape: this.shape,
      stridesBytes: this.stridesBytes,
      offsetBytes: this.offsetBytes,
    });
    // Copy at most `byteLength` bytes. If read() returns a 4-byte-padded result (WebGPU
    // copies and mappings work in multiples of 4 bytes), the tail must not overflow the
    // CPU backing region.
    cpu.backingBytes().set(new Uint8Array(bytes).subarray(0, cpu.byteLength), 0);
    return cpu;
  }

  /** Destroys the StorageBuffer if this array owns it; borrowed buffers are left untouched. */
  destroy(): void {
    if (this.owned) this.buffer.destroy();
  }
}