// luma.gl
// SPDX-License-Identifier: MIT
// Copyright (c) vis.gl contributors

import type {
  Bindings,
  BindingsByGroup,
  ComputeShaderLayout,
  ShaderLayout
} from '../adapter/types/shader-layout';
import type {Device} from '../adapter/device';
import type {ComputePipeline} from '../adapter/resources/compute-pipeline';
import type {RenderPipeline} from '../adapter/resources/render-pipeline';
import {normalizeBindingsByGroup} from '../adapter-utils/bind-groups';

type AnyPipeline = RenderPipeline | ComputePipeline;
type AnyShaderLayout = ShaderLayout | ComputeShaderLayout;

/** Optional per-group cache key objects (objects, because they key a WeakMap). */
type BindGroupCacheKeys = Partial<Record<number, object>>;
/** Resolved bind groups, indexed by bind group number. */
type BindGroupMap = Partial<Record<number, unknown>>;
/** Bind group layouts created for one pipeline, indexed by bind group number. */
type LayoutCache = Partial<Record<number, object>>;
/** Bind groups created against one bind group layout. */
type LayoutBindGroupCache = {
  /** Bind groups keyed by the caller-supplied cache key object. */
  bindGroupsBySource: WeakMap<object, unknown>;
  /** Shared bind group for groups the shader declares no bindings in. */
  emptyBindGroup?: unknown;
};

/**
 * Creates and caches WebGPU bind groups for render and compute pipelines.
 *
 * Bind group layouts are cached per pipeline and group index. Bind groups are cached
 * per layout, either under a caller-supplied cache key object or, for groups the shader
 * declares no bindings in, as one shared empty bind group. All caches are WeakMaps, so
 * entries are not kept alive by this factory once their pipeline, layout or key object
 * is otherwise unreachable.
 */
export class BindGroupFactory {
  readonly device: Device;

  private readonly _layoutCacheByPipeline: WeakMap<AnyPipeline, LayoutCache> = new WeakMap();
  private readonly _bindGroupCacheByLayout: WeakMap<object, LayoutBindGroupCache> =
    new WeakMap();

  constructor(device: Device) {
    this.device = device;
  }

  /**
   * Resolves bind groups for every group index from 0 up to the highest group declared
   * in `pipeline.shaderLayout.bindings`.
   *
   * - A group with supplied bindings gets a bind group built from them. If a cache key
   *   is supplied for that group, the bind group created for that key (and layout) is
   *   reused on later calls WITHOUT re-checking the bindings — callers must pass a new
   *   key object when the group's resources change.
   * - A group with no supplied bindings and no bindings declared in the shader gets the
   *   shared empty bind group for its layout (this covers gaps in group numbering).
   * - A group with no supplied bindings but with declared bindings is left unset.
   *   NOTE(review): presumably reported downstream by validation — confirm.
   *
   * @param pipeline Render or compute pipeline whose shader layout drives the groups.
   * @param bindings Flat or per-group bindings; normalized via `normalizeBindingsByGroup`.
   * @param bindGroupCacheKeys Optional per-group cache key objects.
   * @returns Map from group index to bind group. Empty for non-WebGPU devices and for
   *   pipelines whose shader layout declares no bindings.
   */
  getBindGroups(
    pipeline: AnyPipeline,
    bindings?: Bindings | BindingsByGroup,
    bindGroupCacheKeys?: BindGroupCacheKeys
  ): BindGroupMap {
    const {shaderLayout} = pipeline;
    if (this.device.type !== 'webgpu' || shaderLayout.bindings.length === 0) {
      return {};
    }

    const bindingsByGroup = normalizeBindingsByGroup(shaderLayout, bindings);
    const resolvedBindGroups: BindGroupMap = {};

    for (const group of getBindGroupIndicesUpToMax(shaderLayout.bindings)) {
      const groupBindings = bindingsByGroup[group];
      const bindGroupLayout = this._getBindGroupLayout(pipeline, group);

      if (!groupBindings || Object.keys(groupBindings).length === 0) {
        if (!hasBindingsInGroup(shaderLayout.bindings, group)) {
          resolvedBindGroups[group] = this._getEmptyBindGroup(pipeline, bindGroupLayout, group);
        }
        continue;
      }

      const bindGroupCacheKey = bindGroupCacheKeys?.[group];
      resolvedBindGroups[group] = bindGroupCacheKey
        ? this._getKeyedBindGroup(
            pipeline,
            bindGroupLayout,
            groupBindings,
            group,
            bindGroupCacheKey
          )
        : this._createBindGroup(pipeline, bindGroupLayout, groupBindings, group);
    }

    return resolvedBindGroups;
  }

  /** Returns the bind group layout for `group`, creating and caching it on first use. */
  private _getBindGroupLayout(pipeline: AnyPipeline, group: number): object {
    let layoutCache = this._layoutCacheByPipeline.get(pipeline);
    if (!layoutCache) {
      layoutCache = {};
      this._layoutCacheByPipeline.set(pipeline, layoutCache);
    }

    // Use a local: TypeScript does not narrow `layoutCache[group]` for a non-literal key,
    // so returning the indexed access would be typed `object | undefined`.
    let bindGroupLayout = layoutCache[group];
    if (!bindGroupLayout) {
      bindGroupLayout = this.device._createBindGroupLayoutWebGPU(pipeline, group) as object;
      layoutCache[group] = bindGroupLayout;
    }
    return bindGroupLayout;
  }

  /**
   * Returns the bind group cached under `cacheKey` for this layout, creating and caching
   * it on first use. Results are cached even when they are `null`.
   */
  private _getKeyedBindGroup(
    pipeline: AnyPipeline,
    bindGroupLayout: object,
    groupBindings: Bindings,
    group: number,
    cacheKey: object
  ): unknown {
    const {bindGroupsBySource} = this._getLayoutBindGroupCache(bindGroupLayout);
    if (bindGroupsBySource.has(cacheKey)) {
      return bindGroupsBySource.get(cacheKey) ?? null;
    }

    const bindGroup = this._createBindGroup(pipeline, bindGroupLayout, groupBindings, group);
    bindGroupsBySource.set(cacheKey, bindGroup);
    return bindGroup;
  }

  /**
   * Returns the shared empty bind group for this layout, creating and caching it on
   * first use. The label comes from the pipeline that first triggered creation.
   */
  private _getEmptyBindGroup(
    pipeline: AnyPipeline,
    bindGroupLayout: object,
    group: number
  ): unknown {
    const layoutCache = this._getLayoutBindGroupCache(bindGroupLayout);
    // Compare against `undefined` rather than using `||=`, so a `null` result is cached
    // too instead of being re-created on every call.
    if (layoutCache.emptyBindGroup === undefined) {
      layoutCache.emptyBindGroup = this._createBindGroup(pipeline, bindGroupLayout, {}, group);
    }
    return layoutCache.emptyBindGroup;
  }

  /**
   * Creates an uncached bind group, normalizing a missing result to `null`.
   * The debug label is only built here, so cache hits never pay for it.
   */
  private _createBindGroup(
    pipeline: AnyPipeline,
    bindGroupLayout: object,
    groupBindings: Bindings,
    group: number
  ): unknown {
    const {shaderLayout} = pipeline;
    const label = getBindGroupLabel(pipeline, shaderLayout, group);
    return (
      this.device._createBindGroupWebGPU(
        bindGroupLayout,
        shaderLayout,
        groupBindings,
        group,
        label
      ) ?? null
    );
  }

  /** Returns the bind group cache entry for `bindGroupLayout`, creating it if needed. */
  private _getLayoutBindGroupCache(bindGroupLayout: object): LayoutBindGroupCache {
    let layoutCache = this._bindGroupCacheByLayout.get(bindGroupLayout);
    if (!layoutCache) {
      layoutCache = {bindGroupsBySource: new WeakMap()};
      this._bindGroupCacheByLayout.set(bindGroupLayout, layoutCache);
    }
    return layoutCache;
  }
}

/**
 * Returns the device's shared `BindGroupFactory`, creating it on first use and storing
 * it in `device._factories.bindGroupFactory`.
 */
export function _getDefaultBindGroupFactory(device: Device): BindGroupFactory {
  device._factories.bindGroupFactory ||= new BindGroupFactory(device);
  return device._factories.bindGroupFactory as BindGroupFactory;
}

/**
 * Returns `[0, 1, ..., maxGroup]`, where `maxGroup` is the highest `group` among the
 * bindings. Returns `[]` for an empty binding list.
 */
function getBindGroupIndicesUpToMax(bindings: AnyShaderLayout['bindings']): number[] {
  const maxGroup = bindings.reduce(
    (highestGroup, binding) => Math.max(highestGroup, binding.group),
    -1
  );
  return Array.from({length: maxGroup + 1}, (_, group) => group);
}

/** True if any binding in the shader layout is declared in `group`. */
function hasBindingsInGroup(bindings: AnyShaderLayout['bindings'], group: number): boolean {
  return bindings.some(binding => binding.group === group);
}

/**
 * Builds a debug label `<pipelineId>/group<N>[<names>]`. The names are the group's binding
 * names ordered by `location`, or `empty` if the group declares no bindings.
 */
function getBindGroupLabel(
  pipeline: AnyPipeline,
  shaderLayout: AnyShaderLayout,
  group: number
): string {
  const bindingNames = shaderLayout.bindings
    .filter(binding => binding.group === group)
    // `filter` returned a fresh array, so sorting in place does not touch the layout.
    .sort((left, right) => left.location - right.location)
    .map(binding => binding.name);
  const bindingSuffix = bindingNames.length > 0 ? bindingNames.join(',') : 'empty';
  return `${pipeline.id}/group${group}[${bindingSuffix}]`;
}