// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';

// TODO: support:
// - ceil_mode                 "test_maxpool_2d_ceil"
// - storage_order             "test_maxpool_with_argmax_2d_precomputed_strides"
// - [MaxPool] dilations       "test_maxpool_2d_dilations"
// - [MaxPool] output[1]       "test_maxpool_with_argmax_2d_precomputed_pads"

/**
 * Checks that a pooling operator received exactly one input.
 *
 * The check only runs when `env.webgpu.validateInputContent` is truthy.
 *
 * @throws Error when validation is enabled and `inputs` is missing or does not hold exactly one tensor.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (env.webgpu.validateInputContent && (!inputs || inputs.length !== 1)) {
    throw new Error('Pool ops requires 1 input.');
  }
};

/**
 * Produces a copy of `attributes` whose kernelShape/strides/pads (and dilations, when the attribute object has them)
 * have been passed through `PoolConvUtil.adjustPoolAttributes`, together with the output shape.
 *
 * The shape computation is done in channel-first order; for 'NHWC' inputs the channel dimension is moved to
 * position 1 beforehand and moved back to the last position in the returned shape.
 *
 * @param input the tensor being pooled.
 * @param attributes the parsed operator attributes; not mutated.
 * @param isGlobalOperator forwarded to the `PoolConvUtil` helpers.
 * @returns `[adjustedAttributes, outputShape]`, with `outputShape` in the same layout as `attributes.format`.
 */
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
    input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
  const isChannelsLast = attributes.format === 'NHWC';
  const inputShapeAsChannelFirst = input.dims.slice();
  if (isChannelsLast) {
    inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!);  // Move channel to the second position.
  }
  const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
  // Work on copies: adjustPoolAttributes receives these arrays and the results are stored on the returned attributes.
  const kernelShape = attributes.kernelShape.slice();
  const strides = attributes.strides.slice();
  const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
  const pads = attributes.pads.slice();
  PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);

  const outputShapeAsChannelFirst = PoolConvUtil.computePoolOutputShape(
      isGlobalOperator, inputShapeAsChannelFirst, strides, dilations, kernelShape, pads, attributes.autoPad);

  const newAttributes = Object.assign({}, attributes);
  if (hasDilations) {
    Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey});
  } else {
    Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
  }
  const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
  // Move the channel dimension (index 1) to the end.
  outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
  return [newAttributes, isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst];
};

/**
 * Builds the uniform values and uniform declarations consumed by the pooling shader.
 *
 * For kernels of rank <= 2 the kernel size, stride and start/end padding of each spatial axis are passed as scalar
 * uniforms (`kw`, `sw`, `pwStart`, `pwEnd` and, for rank 2, `kh`, `sh`, `phStart`, `phEnd`). For higher ranks the
 * kernel strides, pads and strides are passed as uniform arrays.
 *
 * @param outputShape shape of the output tensor.
 * @param attributes adjusted pooling attributes.
 * @returns `[programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero]`. In the rank <= 2 case
 *     `hasPads` is always `true`; in the N-D case the two per-axis flags are always `false`.
 * @throws Error when the kernel rank is > 2 and the format is 'NHWC'.
 */
const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
    outputShape: readonly number[],
    attributes: AttributeType): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => {
  const isChannelsLast = attributes.format === 'NHWC';
  const outputSize = ShapeUtil.size(outputShape);
  const kernelSize = ShapeUtil.size(attributes.kernelShape);
  const programUniforms: ProgramUniform[] =
      [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: kernelSize}];
  const uniforms: UniformsArrayType = [{name: 'outputSize', type: 'u32'}, {name: 'kernelSize', type: 'u32'}];
  if (attributes.kernelShape.length <= 2) {
    // `pads` holds all start values followed by all end values; the last spatial axis is treated as width.
    const kw = attributes.kernelShape[attributes.kernelShape.length - 1];
    const sw = attributes.strides[attributes.strides.length - 1];
    const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
    const pwEnd = attributes.pads[attributes.pads.length - 1];
    const pwStartEndNotZero = !!(pwStart + pwEnd);
    programUniforms.push(
        {type: DataType.uint32, data: kw},
        {type: DataType.uint32, data: sw},
        {type: DataType.uint32, data: pwStart},
        {type: DataType.uint32, data: pwEnd},
    );
    uniforms.push(
        {name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'},
        {name: 'pwEnd', type: 'u32'});

    let phStartEndNotZero = false;
    if (attributes.kernelShape.length === 2) {
      const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
      const sh = attributes.strides[attributes.strides.length - 2];
      const phStart = attributes.pads[attributes.pads.length / 2 - 2];
      const phEnd = attributes.pads[attributes.pads.length - 2];
      phStartEndNotZero = !!(phStart + phEnd);
      programUniforms.push(
          {type: DataType.uint32, data: kh}, {type: DataType.uint32, data: sh}, {type: DataType.uint32, data: phStart},
          {type: DataType.uint32, data: phEnd});

      uniforms.push(
          {name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'},
          {name: 'phEnd', type: 'u32'});
    }
    return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
  } else {
    if (isChannelsLast) {
      throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
    }
    const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
    programUniforms.push(
        {type: DataType.uint32, data: kernelStrides}, {type: DataType.uint32, data: attributes.pads},
        {type: DataType.uint32, data: attributes.strides});
    uniforms.push(
        {name: 'kernelStrides', type: 'u32', length: kernelStrides.length},
        {name: 'pads', type: 'u32', length: attributes.pads.length},
        {name: 'strides', type: 'u32', length: attributes.strides.length});

    // The explicit initial value keeps reduce() from throwing on an empty pads array.
    const hasPads = attributes.pads.reduce((sum, cur) => sum + cur, 0);
    return [programUniforms, uniforms, !!hasPads, false, false];
  }
};

/**
 * Generates the WGSL source of a pooling shader.
 *
 * Each invocation computes one output element: it starts from `start`, applies `op1` to every input element
 * (`x_val`) of the pooling window, then applies `op2` and stores `value` in the output. `pad` counts the window
 * positions that were skipped because they fall outside the input.
 *
 * @param shaderHelper helper used to register uniforms and declare variables.
 * @param x the input variable helper.
 * @param rank rank of the input tensor.
 * @param outputShapeRank rank of the output tensor.
 * @param attributes adjusted pooling attributes.
 * @param op1 WGSL statement(s) executed for each in-bounds window element.
 * @param op2 WGSL statement(s) executed once after the window has been visited.
 * @param start initial value of the accumulator.
 * @param uniforms uniform declarations to register.
 * @param hasPads whether the N-D code path has to emit bounds checks.
 * @param pwStartEndNotZero whether the width loop has to emit bounds checks (kernel rank <= 2).
 * @param phStartEndNotZero whether the height loop has to emit bounds checks (kernel rank 2).
 * @returns the WGSL shader source.
 * @throws Error when the kernel rank is > 2 and the format is 'NHWC'.
 */
const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
    shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType,
    op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEndNotZero: boolean,
    phStartEndNotZero: boolean): string => {
  const isChannelsLast = attributes.format === 'NHWC';
  const dataType = x.type.value;
  const output = outputVariable('output', x.type.tensor, outputShapeRank);

  if (attributes.kernelShape.length <= 2) {
    let codeW = '';
    let codeH = '';
    let codeHEnd = '';
    const dimIdxW = rank - (isChannelsLast ? 2 : 1);
    if (pwStartEndNotZero) {
      codeW = `
                for (var i: u32 = 0u; i < uniforms.kw; i++) {
                  xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
                  if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}]
                      >= uniforms.x_shape[${dimIdxW}]) {
                    pad++;
                    continue;
                  }
                  let x_val = x[${x.indicesToOffset('xIndices')}];
                  ${op1}
                }`;
    } else {
      codeW = `
                for (var i: u32 = 0u; i < uniforms.kw; i++) {
                  xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
                  let x_val = x[${x.indicesToOffset('xIndices')}];
                  ${op1}
                }`;
    }

    if (attributes.kernelShape.length === 2) {
      const dimIdxH = rank - (isChannelsLast ? 3 : 2);
      if (phStartEndNotZero) {
        // A whole row outside the input counts as `kw` padded elements.
        codeH = `
                for (var j: u32 = 0u; j < uniforms.kh; j++) {
                  xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
                  if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) {
                    pad += i32(uniforms.kw);
                    continue;
                  }
              `;
      } else {
        codeH = `
                for (var j: u32 = 0u; j < uniforms.kh; j++) {
                  xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
                `;
      }
      codeHEnd = `
              }
            `;
    }

    const poolingCode = `
            ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}

            ${shaderHelper.mainStart()}
              ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}

              let indices = ${output.offsetToIndices('global_idx')};
              var xIndices = ${output.offsetToIndices('global_idx')};

              var value = ${dataType}(${start});
              var pad = 0;
              ${codeH}
              ${codeW}
              ${codeHEnd}
              ${op2}

              output[global_idx] = value;
            }`;
    return poolingCode;
  } else {
    if (isChannelsLast) {
      throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
    }
    const stridesRank = attributes.kernelShape.length;
    const padsRank = attributes.pads.length;
    // Index of the first spatial dimension in the input; used to map an input axis to its strides/offsets/pads entry.
    const spatialStart = rank - stridesRank;
    let padCode = '';
    if (hasPads) {
      // Closes the per-dimension loop opened in `poolingCode` and only reads the input when no index is out of range.
      padCode = `
                if (xIndices[j] >= uniforms.x_shape[j]) {
                  pad++;
                  isPad = true;
                  break;
                }
              }
              if (!isPad) {
                let x_val = x[${x.indicesToOffset('xIndices')}];
                ${op1}
              }`;
    } else {
      // Closes the per-dimension loop opened in `poolingCode`.
      padCode = `
              }
              let x_val = x[${x.indicesToOffset('xIndices')}];
              ${op1}
            `;
    }
    const poolingCode = `
            ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}

            ${shaderHelper.mainStart()}
              ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
              let indices = ${output.offsetToIndices('global_idx')};
              var xIndices = ${output.offsetToIndices('global_idx')};

              var offsets: array<u32, ${stridesRank}>;

              var value = ${dataType}(${start});
              var pad = 0;
              var isPad = false;

              for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
                var offset = i;
                for (var j = 0u; j < ${stridesRank - 1}u; j++) {
                  offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
                  offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
                }
                offsets[${stridesRank - 1}] = offset;

                isPad = false;
                for (var j = ${spatialStart}u; j < ${rank}u; j++) {
                  xIndices[j] = indices[j] * ${
        getElementAt('uniforms.strides', `j - ${spatialStart}u`, stridesRank)}
                    + offsets[j - ${spatialStart}u] - ${
        getElementAt('uniforms.pads', `j - ${spatialStart}u`, padsRank)};
                  ${padCode}
              }
              ${op2}

              output[global_idx] = value;
            }`;
    return poolingCode;
  }
};

/** Memory layout of the pooled tensor. */
export interface FormatAttributes {
  readonly format: 'NHWC'|'NCHW';
}

/** Attributes shared by the average and max pooling operators. */
export interface PoolCommonAttributes extends FormatAttributes {
  readonly autoPad: string;
  readonly ceilMode: number;
  readonly kernelShape: readonly number[];
  readonly strides: readonly number[];
  readonly pads: readonly number[];
}

/** Builds the part of the shader cache key that is common to all pooling operators. */
const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
    (`${attributes.format};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`);

/** Builds the shader cache key for AveragePool: the common key plus `countIncludePad`. */
const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string =>
    (`${createShaderKeyFromAttributes(attributes)};${attributes.countIncludePad}`);

/** Builds the shader cache key for MaxPool: the common key plus `storageOrder` and `dilations`. */
const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string =>
    (`${createShaderKeyFromAttributes(attributes)};${attributes.storageOrder};${attributes.dilations}`);

/**
 * Extracts the attributes common to all pooling operators from the raw attribute record.
 *
 * `auto_pad` is read as an index into ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'].
 */
const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
  format: attributes.format as FormatAttributes['format'],
  autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
  ceilMode: attributes.ceil_mode as number,
  kernelShape: attributes.kernel_shape as [number, number],
  strides: attributes.strides as [number, number],
  pads: attributes.pads as [number, number, number, number]
});

/** Attributes of the AveragePool / GlobalAveragePool operators. */
export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
  readonly countIncludePad: boolean;
}

/**
 * Creates the program info for an average pooling operator.
 *
 * The shader sums the window elements and divides by the kernel size, or by the kernel size minus the number of
 * padded positions when `countIncludePad` is false.
 *
 * @param name program name.
 * @param input the tensor being pooled.
 * @param isGlobalOperator whether this is the global variant of the operator.
 * @param attributes parsed operator attributes.
 */
const createAveragePoolProgramInfo =
    (name: string, input: TensorView, isGlobalOperator: boolean, attributes: AveragePoolAttributes): ProgramInfo => {
      const [adjustedAttributes, outputShape] =
          getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator);
      const x = inputVariable('x', input.dataType, input.dims.length);
      const dataType = x.type.value;

      const op1 = 'value += x_val;';
      let op2 = '';
      if (adjustedAttributes.countIncludePad) {
        op2 += `value /= ${dataType}(uniforms.kernelSize);`;
      } else {
        op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;
      }
      const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
          getUniformAndPadInfo(outputShape, adjustedAttributes);
      programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
      return {
        name,
        // The pad flags select different shader code, so they are part of the cache hint.
        shaderCache:
            {hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies},
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: input.dataType}],
          dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
          programUniforms
        }),
        getShaderSource: shaderHelper => generatePoolingCode(
            shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms,
            hasPads, pwStartEndNotZero, phStartEndNotZero),
      };
    };

/**
 * Parses the raw AveragePool attributes and attaches the shader cache key.
 *
 * @throws Error when `ceil_mode` is not 0.
 */
export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
  const countIncludePad = (attributes.count_include_pad as number) !== 0;

  const attr = parsePoolCommonAttributes(attributes);
  // TODO: support attribute 'ceil_mode'
  if (attr.ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
  }
  const averagePoolAttributes = {countIncludePad, ...attr, cacheKey: ''};
  return {...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes)};
};

/** Runs AveragePool on the single input of `context`. */
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
  validateInputs(context.inputs);
  context.compute(createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes));
};

/** Attribute values shared by the global pooling operators; only `format` and `cacheKey` are added per node. */
const globalPoolAttributes = {
  autoPad: '',
  ceilMode: 0,
  countIncludePad: false,
  kernelShape: [],
  strides: [],
  pads: [],
  storageOrder: 0,
  dilations: []
};

/** Parses the raw GlobalAveragePool attributes; only `format` is read and it doubles as the cache key. */
export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
  const format = attributes.format as FormatAttributes['format'];
  return {format, ...globalPoolAttributes, cacheKey: format};
};

/** Runs GlobalAveragePool on the single input of `context`. */
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
  validateInputs(context.inputs);
  context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
};

/** Attributes of the MaxPool / GlobalMaxPool operators. */
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
  readonly storageOrder: number;
  readonly dilations: number[];
}

/**
 * Returns the value the max-pool accumulator starts from for the given tensor data type.
 *
 * The accumulator must not exceed any input value, otherwise a window made only of very small values would yield
 * the start value instead of its real maximum. float16 and float therefore start from their lowest finite value;
 * any other data type keeps the historical start value.
 */
const getMaxPoolStartValue = (dataType: number): number => {
  if (dataType === DataType.float16) {
    return -65504;
  }
  if (dataType === DataType.float) {
    // Lowest finite float32; exactly representable, so it is a valid argument of the WGSL f32() conversion.
    return -3.4028234663852886e38;
  }
  return -1e5;
};

/**
 * Creates the program info for a max pooling operator.
 *
 * @param name program name.
 * @param input the tensor being pooled.
 * @param isGlobalOperator whether this is the global variant of the operator.
 * @param attributes parsed operator attributes.
 */
const createMaxPoolProgramInfo =
    (name: string, input: TensorView, isGlobalOperator: boolean, attributes: MaxPoolAttributes): ProgramInfo => {
      const [adjustedAttributes, outputShape] =
          getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator);
      const op1 = `
      value = max(x_val, value);
    `;
      const op2 = '';
      const x = inputVariable('x', input.dataType, input.dims.length);
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
      const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
          getUniformAndPadInfo(outputShape, adjustedAttributes);
      programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
      return {
        name,
        // The pad flags select different shader code, so they are part of the cache hint.
        shaderCache:
            {hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies},
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: input.dataType}],
          dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
          programUniforms
        }),
        getShaderSource: shaderHelper => generatePoolingCode(
            shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2,
            getMaxPoolStartValue(input.dataType), uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero),
      };
    };

/** Runs MaxPool on the single input of `context`. */
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
  validateInputs(context.inputs);
  context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes));
};

/**
 * Parses the raw MaxPool attributes and attaches the shader cache key.
 *
 * @throws Error when `storage_order` or `ceil_mode` is not 0.
 */
export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
  const storageOrder = attributes.storage_order as number;
  const dilations = attributes.dilations as [number, number];

  const attr = parsePoolCommonAttributes(attributes);
  // TODO: support attribute 'ceil_mode' and 'storage_order'
  if (storageOrder !== 0) {
    throw new Error('column major storage order is not yet supported for MaxPool');
  }
  if (attr.ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
  }
  const maxPoolAttributes = {storageOrder, dilations, ...attr, cacheKey: ''};
  return {...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes)};
};

/** Parses the raw GlobalMaxPool attributes; only `format` is read and it doubles as the cache key. */
export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
  const format = attributes.format as FormatAttributes['format'];
  return {format, ...globalPoolAttributes, cacheKey: format};
};

/** Runs GlobalMaxPool on the single input of `context`. */
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
  validateInputs(context.inputs);
  context.compute(createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes));
};