///
import { HasReactive, reactively } from "@reactively/decorate";
import {
assignParams,
createDebugBuffer,
gpuTiming,
limitWorkgroupLength,
reactiveTrackUse,
trackContext,
} from "thimbleberry";
import { computePipeline } from "../util/ComputePipeline.js";
import { calcDispatchSizes } from "../util/DispatchSizes.js";
import { Cache, ComposableShader } from "../util/Util.js";
import wgsl from "./ApplyScanBlocks.wgsl?raw";
import { BinOpModule } from "../util/BinOpModules.js";
import { ModuleRegistry } from "wgsl-linker";
import { sumU32 } from "../binop/BinOpModuleSumU32.js";
/** @internal */
export interface ApplyScanBlocksArgs {
  /** GPU device used to create the pipeline, bind groups and buffers */
  device: GPUDevice;

  /** Partially scanned input values. Its byte size sets how many u32 elements
   * are dispatched over and the byte size of the result buffer. */
  partialScan: GPUBuffer;
  /** Per-block summary values, bound read-only alongside `partialScan` */
  blockSums: GPUBuffer;

  /** Starting offset into `partialScan` for the first dispatch
   * (presumably in elements, since it advances by invocation count per dispatch) */
  partialScanOffset?: number;
  /** Starting offset into `blockSums` for the first dispatch; advances by the
   * workgroup count of each dispatch */
  blockSumsOffset?: number;
  /** Starting scan offset written to the uniforms for the first dispatch */
  scanOffset?: number;

  /** WGSL binary-operation module linked into the shader (`defaults`: sumU32) */
  binOps?: BinOpModule;
  /** Passed to the shader as a 1/0 uniform flag; presumably selects an exclusive
   * scan — confirm against ApplyScanBlocks.wgsl */
  exclusiveLarge?: boolean;
  /** Written to the shader uniforms; an undefined value ends up as 0 */
  initialValue?: number;

  /** Override for the workgroup size, passed through limitWorkgroupLength */
  forceWorkgroupLength?: number;
  /** Override for the per-dispatch workgroup limit; otherwise the device's
   * maxComputeWorkgroupsPerDimension limit is used */
  forceMaxWorkgroups?: number | undefined;

  /** Prefix for GPU object labels */
  label?: string;
  /** Optional pipeline cache provider handed to computePipeline */
  pipelineCache?: () => Cache;
}
/** Fallback values applied by `assignParams` for any
 * {@link ApplyScanBlocksArgs} field the caller omits. */
const defaults: Partial<ApplyScanBlocksArgs> = {
  binOps: sumU32,
  label: "",
  exclusiveLarge: false,
  // left undefined deliberately; writeUniforms encodes a missing value as 0
  initialValue: undefined,
  // undefined => use the device's maxComputeWorkgroupsPerDimension limit
  forceMaxWorkgroups: undefined,
  partialScanOffset: 0,
  blockSumsOffset: 0,
  scanOffset: 0,
};
/** Shader stage used in a prefix scan, applies block summaries to block elements.
 *
 * Binds `partialScan` and `blockSums` read-only and writes into a freshly
 * allocated `result` buffer with the same byte size as `partialScan`.
 * Inputs too large for one dispatch are split across several dispatches
 * (see `dispatchSizes`). Each dispatch gets its own uniforms buffer and bind
 * group carrying that dispatch's starting offsets.
 * @internal
 */
export class ApplyScanBlocks extends HasReactive implements ComposableShader {
  @reactively partialScan!: GPUBuffer;
  @reactively blockSums!: GPUBuffer;
  @reactively forceWorkgroupLength?: number;
  @reactively binOps!: BinOpModule;
  @reactively label!: string;
  @reactively exclusiveLarge!: boolean;
  /** NOTE: declared as number, but the module `defaults` list it as undefined,
   * so it may be undefined at runtime; writeUniforms encodes that as 0. */
  @reactively initialValue!: number;
  @reactively partialScanOffset!: number;
  @reactively blockSumsOffset!: number;
  @reactively scanOffset!: number;
  // not @reactively: changing it after construction will not invalidate maxWorkgroups
  private forceMaxWorkgroups?: number;
  private device!: GPUDevice;
  private usageContext = trackContext();
  private pipelineCache?: () => Cache;

  constructor(params: ApplyScanBlocksArgs) {
    super();
    assignParams(this, params, defaults);
  }

  /** Encode one compute pass per dispatch into `commandEncoder`.
   * The per-dispatch uniform buffers are (re)written via updateUniforms
   * before any pass is encoded. */
  commands(commandEncoder: GPUCommandEncoder): void {
    this.updateUniforms();
    const bindGroups = this.bindGroups;
    this.dispatchSizes.forEach((dispatchSize, i) => {
      const dispatchLabel = `${this.label} ${dispatchSize} ${i}`;
      const timestampWrites = gpuTiming?.timestampWrites(dispatchLabel);
      const passEncoder = commandEncoder.beginComputePass({ timestampWrites });
      passEncoder.label = ` ${this.label} apply scan blocks`;
      passEncoder.setPipeline(this.pipeline);
      passEncoder.setBindGroup(0, bindGroups[i]);
      passEncoder.dispatchWorkgroups(dispatchSize, 1, 1);
      passEncoder.end();
    });
  }

  /** Finish the usage-tracking context that the buffers created here were
   * registered with (via reactiveTrackUse). */
  destroy(): void {
    this.usageContext.finish();
  }

  @reactively private get partialScanSize(): number {
    return this.partialScan.size;
  }

  /** Return enough dispatches to cover the source
   * (multiple dispatches are needed for large sources) */
  @reactively private get dispatchSizes(): number[] {
    const sourceElems = this.partialScanSize / Uint32Array.BYTES_PER_ELEMENT;
    const maxWorkgroups = this.maxWorkgroups;
    return calcDispatchSizes(sourceElems, this.workgroupLength, maxWorkgroups);
  }

  /** Per-dispatch workgroup limit: the forced override, else the device limit */
  @reactively private get maxWorkgroups(): number {
    return this.forceMaxWorkgroups ?? this.device.limits.maxComputeWorkgroupsPerDimension;
  }

  /** WGSL module registry: this stage's shader plus the selected binOps module */
  @reactively private get registry(): ModuleRegistry {
    return new ModuleRegistry({ wgsl: { main: wgsl }, rawWgsl: [this.binOps.wgsl] });
  }

  /** Compute pipeline. Bindings: 0 uniforms, 1 partialScan (read-only),
   * 2 blockSums (read-only), 3 result (storage); debug buffer enabled. */
  @reactively private get pipeline(): GPUComputePipeline {
    const compute = computePipeline(
      {
        device: this.device,
        label: this.label,
        registry: this.registry,
        constants: {
          workgroupSizeX: this.workgroupLength,
        },
        bindings: [
          { buffer: { type: "uniform" } },
          { buffer: { type: "read-only-storage" } },
          { buffer: { type: "read-only-storage" } },
          { buffer: { type: "storage" } },
        ],
        debugBuffer: true,
      },
      this.pipelineCache,
    );
    return compute.pipeline;
  }

  /** One bind group per dispatch, differing only in the uniforms buffer */
  @reactively private get bindGroups(): GPUBindGroup[] {
    return this.dispatchSizes.map((_, i) => this.createBindGroup(i));
  }

  private createBindGroup(index: number): GPUBindGroup {
    return this.device.createBindGroup({
      label: `${this.label} ${index} apply scan blocks`,
      layout: this.pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: this.uniforms[index] } },
        { binding: 1, resource: { buffer: this.partialScan } },
        { binding: 2, resource: { buffer: this.blockSums } },
        { binding: 3, resource: { buffer: this.result } },
        // binding 11 is presumably the slot reserved by `debugBuffer: true` in the pipeline
        { binding: 11, resource: { buffer: this.debugBuffer } },
      ],
    });
  }

  /** Output buffer, same byte size as partialScan; usable as storage and as a copy source */
  @reactively get result(): GPUBuffer {
    const buffer = this.device.createBuffer({
      label: `${this.label} apply scan blocks result`,
      size: this.partialScanSize,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
    });
    reactiveTrackUse(buffer, this.usageContext);
    return buffer;
  }

  @reactively private get workgroupLength(): number {
    return limitWorkgroupLength(this.device, this.forceWorkgroupLength);
  }

  @reactively get debugBuffer(): GPUBuffer {
    const buffer = createDebugBuffer(this.device, "ApplyScanBlocks debug");
    reactiveTrackUse(buffer, this.usageContext);
    return buffer;
  }

  // TODO use one uniform buffer, with dynamic offsets instead
  @reactively private get uniforms(): GPUBuffer[] {
    return this.dispatchSizes.map((_, i) => this.uniformsBuffer(i));
  }

  /** Allocate one dispatch's uniforms buffer: 8 u32 words (only the first 5 are written) */
  private uniformsBuffer(index: number): GPUBuffer {
    const buffer = this.device.createBuffer({
      label: `${this.label} ${index} apply scan blocks uniforms`,
      size: Uint32Array.BYTES_PER_ELEMENT * 8,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    reactiveTrackUse(buffer, this.usageContext);
    return buffer;
  }

  /** Write each dispatch's starting offsets into its uniforms buffer.
   * After each dispatch the partialScan and scan offsets advance by
   * dispatchSize * workgroupLength (the 1-D invocation count), and the
   * blockSums offset advances by dispatchSize (the workgroup count). */
  @reactively private updateUniforms(): void {
    const workgroupLength = this.workgroupLength;
    const uniforms = this.uniforms;
    let partialScanOffset = this.partialScanOffset;
    let scanOffset = this.scanOffset;
    let blockSumsOffset = this.blockSumsOffset;
    // forEach, not map: the loop exists only for its side effects
    this.dispatchSizes.forEach((dispatchSize, i) => {
      this.writeUniforms(uniforms[i], partialScanOffset, scanOffset, blockSumsOffset);
      const invocations = dispatchSize * workgroupLength;
      partialScanOffset += invocations;
      scanOffset += invocations;
      blockSumsOffset += dispatchSize;
    });
  }

  /** Upload one dispatch's uniforms as u32 words:
   * [partialScanOffset, scanOffset, blockSumsOffset, exclusive (1/0), initialValue] */
  private writeUniforms(
    uniforms: GPUBuffer,
    partialScanOffset: number,
    scanOffset: number,
    blockSumsOffset: number,
  ): void {
    const exclusive = this.exclusiveLarge ? 1 : 0;
    // initialValue may be undefined (see defaults). Uint32Array would coerce it
    // via NaN to 0 anyway; say so explicitly.
    const initialValue = this.initialValue ?? 0;
    const array = new Uint32Array([
      partialScanOffset,
      scanOffset,
      blockSumsOffset,
      exclusive,
      initialValue,
    ]);
    this.device.queue.writeBuffer(uniforms, 0, array);
  }
}