// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

/**
 * Attributes of the InstanceNormalization operator.
 *
 * `epsilon` is added to the variance before the inverse square root is taken.
 * `format` selects which of the two code paths in `instanceNorm` is used.
 */
export interface InstanceNormAttributes {
  epsilon: number;
  format: 'NHWC'|'NCHW';
}

/**
 * Builds the single-pass program used for the non-NHWC layout.
 *
 * The input is viewed as [dims[0], dims[1], normPackedSize], where everything from axis 2 onwards is flattened and
 * packed into vectors of `components` elements. One workgroup of 64 threads is dispatched per (dims[0], dims[1])
 * pair; it reduces the mean and the sum of squared deviations in workgroup memory and then writes
 * `x * channelScale + channelShift`.
 *
 * @param inputs - [x, scale, bias].
 * @param attributes - operator attributes; `epsilon` is baked into the shader text (and into the cache hint).
 * @returns the program description to hand to `context.compute`.
 */
const createInstanceNormProgramInfo =
    (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
      const xShape = inputs[0].dims;
      const outputShape = xShape;
      const axis = 2;
      const normCount = ShapeUtil.sizeToDimension(xShape, axis);
      const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
      const components = getMaxComponents(normSize);
      const normPackedSize = normSize / components;
      const inputShape = [xShape[0], xShape[1], normPackedSize];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
      const programUniforms: ProgramUniform[] =
          [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}];
      // Shape uniforms for `x` and `output`, both of which are declared by rank below.
      programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
        const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
        const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
        const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
        const variables = [x, scale, bias, output];
        const dataType = x.type.value;
        // The element type must be spelled out: a bare `vecN` is not usable as an array element type, and as a
        // constructor it would keep the argument's component type instead of converting to f32.
        const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
        const workgroupSize = 64;

        const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}];
        // Module-scope variables shared by the threads of a workgroup need the explicit `workgroup` address space.
        return `
  var<workgroup> meanShared : f32;
  var<workgroup> squaredNormShared : f32;
  var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
  const workgroupSize = ${workgroupSize}u;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
  ${shaderHelper.mainStart(workgroupSize)}
    let norm = global_idx / workgroupSize;
    let batch = norm / uniforms.x_shape[1];
    let channel = norm % uniforms.x_shape[1];
    let localIndex = local_id.x;

    // initialize workgroup memory
    var initial = ${f32Type}(0);
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
    }
    workgroupShared[localIndex] = initial;
    workgroupBarrier();

    // Calculate the mean of current channel data.
    for (var currSize = workgroupSize >> 1;  currSize > 0; currSize = currSize >> 1) {
      if (localIndex < currSize) {
        workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
      }
      workgroupBarrier();
    }
    if (localIndex == 0) {
      meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
    }
    workgroupBarrier();

    // reinitialize workgroup memory.
    initial = ${f32Type}(0);
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      let deviation =  ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
      initial = initial + deviation * deviation;
    }
    workgroupShared[localIndex] = initial;
    workgroupBarrier();

    // Calculate the sum of square of deviation of current channel data.
    for (var currSize = workgroupSize >> 1;  currSize > 0; currSize = currSize >> 1) {
      if (localIndex < currSize) {
        workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
      }
      workgroupBarrier();
    }
    if (localIndex == 0) {
      squaredNormShared = ${sumVector('workgroupShared[0]', components)};
    }
    workgroupBarrier();

    let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
    let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
    let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
            f32Type}(channelShift));
      ${output.set('batch', 'channel', 'h', 'value')};
    }
  }`;
      };
      return {
        name: 'InstanceNormalization',
        // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
        shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies},
        getRunData: () => ({
          outputs: [
            {dims: outputShape, dataType: inputs[0].dataType},
          ],
          dispatchGroup: {x: normCount},
          programUniforms
        }),
        getShaderSource,
      };
    };

/**
 * Computes, for channel-last data, the per-(image, channel) scale and shift that the NHWC kernel applies.
 *
 * Runs two programs:
 *  1. `InstanceNormComputeMean`: one workgroup of `WG` threads per (image, packed channel); every thread stores the
 *     sum and the sum of squares of its slice of the `h` positions.
 *  2. `InstanceNormComputeChannelScaleShift`: one thread per (image, packed channel) adds up those partial results
 *     and stores `channelScale = invStdDev * scale` and `channelShift = bias - mean * channelScale`.
 *
 * @param context - compute context used to run both programs.
 * @param input - the tensor to normalise.
 * @param scale - per-channel scale.
 * @param bias - per-channel bias.
 * @param n - number of images (first dimension of `input`, as passed by the caller).
 * @param h - number of positions per channel per image.
 * @param c - number of channels.
 * @param epsilon - value added to the variance; baked into the shader text (and into the cache hint).
 * @returns the temporary tensor (dims [n, c, 2]) holding scale and shift.
 */
const computeMean =
    (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
     epsilon: number) => {
      const components = getMaxComponents(c);
      const WG = 64;
      // we will store channel scale and channel shift in [2, components] matrix
      // or in vec2 when components == 1
      const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
      const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
      const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
      const unitsOfWork = n * c / components;
      // Number of positions along `h` handled by each thread of the first pass.
      const wgSize = Math.ceil(h / WG);

      const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
      const meanProgramUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h},
        {type: DataType.uint32, data: Math.floor(c / components)},
        {type: DataType.uint32, data: Math.floor(h * c / components)}
      ];

      const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
        const inputHelper = inputVariable('input', input.dataType, input.dims, components);
        // Threads whose slice starts at or beyond H deliberately do NOT return early: their loop below is empty
        // (wgMax == H <= wgOffset), so they store the untouched accumulators. The second pass reads min(WG, H)
        // slots per channel, so every one of those slots has to be written by this pass. This relies on
        // fillVector's default fill value being zero, which the accumulation itself already requires.
        return `
  ${shaderHelper.declareVariables(inputHelper)}
  @group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
  @group(0) @binding(2) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart(WG)}
    let currentImageNumber = global_idx / ${WG} / uniforms.C;
    let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
    let wgOffset = local_id.x * uniforms.wg_size;
    let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);

    let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = wgOffset; i < wgMax; i++) {
      let value = ${sumCastType}(input[offset + i * uniforms.C]);
      sum += value;
      squaredSum += value * value;
    }
    output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
  }`;
      };

      const meanValues = context.compute(
          {
            name: 'InstanceNormComputeMean',
            shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, WG, 2], dataType: DataType.float},
              ],
              dispatchGroup: {x: unitsOfWork},
              programUniforms: meanProgramUniforms
            }),
            getShaderSource: getMeanShaderSource,
          },
          {inputs: [input], outputs: [-1]})[0];

      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h},
        {type: DataType.uint32, data: Math.floor(c / components)},
        {type: DataType.uint32, data: Math.floor(WG * c / components)}
      ];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
        const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
        // Bindings are declared by hand here, so each one carries its address space and access mode explicitly.
        return `
  @group(0) @binding(0) var<storage, read> input : array<${outputType}>;
  @group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
  @group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
  @group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
  @group(0) @binding(4) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
    let currentImageNumber = global_idx / uniforms.C;
    let currentChannelNumber = global_idx % uniforms.C;

    let offset = currentImageNumber * uniforms.image_size;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
        let value = input[offset + i + currentChannelNumber * ${WG}];
        sum += value[0];
        squaredSum += value[1];
    }
    sum = sum / f32(uniforms.H);
    squaredSum = squaredSum / f32(uniforms.H);
    let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon}));
    let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
    let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;

    output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
  }`;
      };
      return context.compute(
          {
            name: 'InstanceNormComputeChannelScaleShift',
            // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
            shaderCache: {hint: `${components};${epsilon}`, inputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, 2], dataType: DataType.float},
              ],
              dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [meanValues, scale, bias], outputs: [-1]})[0];
    };

/**
 * Runs instance normalisation for channel-last input.
 *
 * N is the first dimension, C the last one, and H the product of everything in between. The per-channel scale and
 * shift are produced by `computeMean`; the program created here then applies them with a single fused multiply-add
 * per packed element.
 *
 * @param context - compute context used to run the programs.
 * @param inputs - [x, scale, bias].
 * @param attributes - operator attributes; only `epsilon` is read here.
 */
const createInstanceNormNHWCProgramInfo =
    (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => {
      const xShape = inputs[0].dims;
      const outputShape = xShape;
      const N = xShape[0];
      const C = xShape[xShape.length - 1];
      const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
      const components = getMaxComponents(C);
      const outputSize = ShapeUtil.size(outputShape) / components;
      const programUniforms: ProgramUniform[] =
          [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
      // first compute mean
      const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
        const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;

        const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
        const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

        // Bindings are declared by hand here, so each one carries its address space and access mode explicitly.
        return `
  @group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
  @group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
  @group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
  struct Uniforms {H: u32, C : u32};
  @group(0) @binding(3) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
    let currentChannelNumber = global_idx % uniforms.C;

    let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
    let scale = scaleInput[scaleOffset];
    output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
  }`;
      };
      context.compute(
          {
            name: 'InstanceNormalizationNHWC',
            shaderCache: {hint: `${components}`, inputDependencies},
            getRunData: () => ({
              outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
              dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [inputs[0], channelScaleShift]});
    };

/**
 * Entry point of the InstanceNormalization operator.
 *
 * Uses the multi-program NHWC path when `attributes.format` is 'NHWC' and the single workgroup-reduction program
 * otherwise.
 *
 * @param context - compute context providing the inputs [x, scale, bias].
 * @param attributes - operator attributes.
 */
export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
  if (attributes.format === 'NHWC') {
    createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
  } else {
    context.compute(createInstanceNormProgramInfo(context.inputs, attributes));
  }
};