// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; import {parseInternalActivationAttributes} from './fuse-utils'; import {createTransposeProgramInfo} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize; const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => { const smallPad = Math.floor(totalPad / 2); if (autoPad === 'SAME_UPPER') { pads[head] = smallPad; pads[tail] = totalPad - smallPad; } else if (autoPad === 'SAME_LOWER') { pads[head] = totalPad - smallPad; pads[tail] = smallPad; } }; const calculateOutputShapeAndPads = (inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], autoPad: string, group: number, pads: number[], strides: readonly number[], isChannelLast: boolean, outputPadding: number[], outputShape: number[]) => { const spatialRank = inputShape.length - 2; const updateOutputShape = outputShape.length === 0; if (outputPadding.length === 0) { for (let i = 0; i < spatialRank; ++i) { outputPadding.push(0); } } const batchSize = inputShape[0]; const outChannels = kernelShape[isChannelLast ? 3 : 1] * group; for (let i = 0, j = inputShape.length - spatialRank - (isChannelLast ? 1 : 0); i < spatialRank; ++i, ++j) { const inSize = inputShape[j]; const outSize = updateOutputShape ? inSize * strides[i] : outputShape[i]; const totalPad = computeTotalPad(inSize, strides[i], pads[i], kernelShape[j], dilations[i], outSize); distributePadding(totalPad, autoPad, pads, i, i + spatialRank); if (updateOutputShape) { outputShape.push( strides[i] * (inSize - 1) + outputPadding[i] + (kernelShape[j] - 1) * dilations[i] + 1 - pads[i] - pads[i + spatialRank]); } } outputShape.splice(0, 0, batchSize); outputShape.splice(isChannelLast ? 3 : 1, 0, outChannels); }; export interface ConvTransposeAttributes extends ConvAttributes { readonly outputPadding: readonly number[]; readonly outputShape: readonly number[]; } const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); } } const isChannelsLast = attributes.format === 'NHWC'; kernelShape.splice(0, 0, inputs[1].dims[0]); kernelShape.splice(isChannelsLast ? 3 : 1, 0, inputs[1].dims[1]); const pads = attributes.pads.slice(); const outputShape = attributes.outputShape.slice(); const outputPadding = attributes.outputPadding.slice(); const inputShape = inputs[0].dims; let dilations = attributes.dilations.slice(); if (dilations.reduce((a, b) => a + b, 0) === 0) { const spatialRank = inputs[0].dims.length - 2; dilations = new Array(spatialRank).fill(1); } let strides = attributes.strides.slice(); if (strides.reduce((a, b) => a + b, 0) === 0) { const spatialRank = inputs[0].dims.length - 2; strides = new Array(spatialRank).fill(1); } // If outputShape is not specified in the attributes of this op, infer it from the parameters // Similarly, automatically infer pads if not specified calculateOutputShapeAndPads( inputShape, kernelShape, dilations, attributes.autoPad, attributes.group, pads, strides, isChannelsLast, outputPadding, outputShape); // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; export const parseConvTransposeAttributes = (attributes: Record): ConvTransposeAttributes => { const activationAttributes = parseInternalActivationAttributes(attributes); // TODO : Make this generic enough to compute default attributes for multi-dimensional conv const format = attributes.format as 'NHWC' | 'NCHW'; const autoPad = ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][typeof attributes.autoPad == 'undefined' ? 0 : attributes.autoPad as number]; const dilations = attributes.dilations as [number, number]; const group = attributes.group as number; const kernelShape = attributes.kernelShape as [number, number]; const pads = attributes.pads as [number, number, number, number]; const strides = attributes.strides as [number, number]; const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; return { autoPad, format, dilations, group, kernelShape, outputPadding, outputShape, pads, strides, wIsConst, ...activationAttributes, cacheKey: `${attributes.format};${activationAttributes.activation};` }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { // Refer to the below link for all input checks // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) { throw new Error('Conv requires 2 or 3 inputs'); } // TODO : Need to add support for multi-dimensional conv if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { throw new Error('currently only support 2-dimensional conv'); } if (inputs[0].dims.length !== inputs[1].dims.length) { throw new Error('filter does not have same dimension as input'); } // FILTER_IN_CHANNEL should be equal to DATA_CHANNEL const dataChannel = inputs[0].dims[attributes.format === 'NHWC' ? inputs[0].dims.length - 1 : 1]; const filterInChannel = inputs[1].dims[0]; if (dataChannel !== filterInChannel) { throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL'); } const featureMaps = inputs[1].dims[1] * attributes.group; // if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) { throw new Error('invalid bias'); } const spatialRank = inputs[0].dims.length - 2; const dilationsSet = attributes.dilations.reduce((a, b) => a + b, 0) > 0; // wrong dilations dimension if (dilationsSet && attributes.dilations.length !== spatialRank) { throw new Error(`dilations should be ${spatialRank}D`); } const stridesSet = attributes.strides.reduce((a, b) => a + b, 0) > 0; // Wrong strides dimension if (stridesSet && attributes.strides.length !== spatialRank) { throw new Error(`strides should be ${spatialRank}D`); } // Wrong pads dimension const padsSet = attributes.pads.reduce((a, b) => a + b, 0) > 0; if (padsSet && attributes.pads.length !== spatialRank * 2) { throw new Error(`pads should be ${spatialRank * 2}D`); } // Wrong output padding dimension if (attributes.outputPadding.length !== spatialRank && attributes.outputPadding.length !== 0) { throw new Error(`output_padding should be ${spatialRank}D`); } // if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor // (the first 2 dims are batch_size and channels) const kernelShapeSet = attributes.kernelShape.reduce((a, b) => a + b, 0) > 0; if (kernelShapeSet && attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } // as with kernelShape, must have same number of spatial dims as input if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { throw new Error('invalid output shape'); } }; // for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] const weightTransposePerm = [2, 3, 1, 0]; const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); const isChannelsLast = attributes.format === 'NHWC'; const outputShape = adjustedAttributes.outputShape; const outChannels = outputShape[isChannelsLast ? 3 : 1]; const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit // utilization rate is very low. if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) { context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes)); return; } const outHeight = outputShape[isChannelsLast ? 1 : 2]; const outWidth = outputShape[isChannelsLast ? 2 : 3]; const weightHeight = inputs[1].dims[2]; const weightWidth = inputs[1].dims[3]; const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; const dimInner = weightHeight * weightWidth * inputChannels; const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( createTransposeProgramInfo(inputs[1], weightTransposePerm), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; } // STEP.2: prepare reshaped inputs const convTransposeInputs = [inputs[0], transposedWeight]; const hasBias = inputs.length === 3; if (hasBias) { if (!isChannelsLast && inputs[2].dims.length === 1) { convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); } else { convTransposeInputs.push(inputs[2]); } } // STEP.3: compute matmul context.compute( createConv2DTransposeMatMulProgramInfo( convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, sequentialAccessByThreads), {inputs: convTransposeInputs}); }; const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; const inputs = [ context.inputs[0].reshape( isChannelLast ? // [N, W, C] -> [N, H=1, W, C] [context.inputs[0].dims[0], 1, context.inputs[0].dims[1], context.inputs[0].dims[2]] : // [N, C, W] -> [N, C, H=1, W] [context.inputs[0].dims[0], context.inputs[0].dims[1], 1, context.inputs[0].dims[2]]), //[FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kW] -> [FILTER_OUT_CHANNEL, FILTER_IN_CHANNEL, kH=1, kW] context.inputs[1].reshape([context.inputs[1].dims[0], context.inputs[1].dims[1], 1, context.inputs[1].dims[2]]) ]; if (context.inputs.length === 3) { inputs.push(context.inputs[2]); } let kernelShape = attributes.kernelShape; if (kernelShape.length === 0 || kernelShape[0] === 0) { kernelShape = [context.inputs[1].dims[2]]; } let dilations = attributes.dilations; if (dilations.length === 0 || dilations[0] === 0) { dilations = [1]; } let strides = attributes.strides; if (strides.length === 0 || strides[0] === 0) { strides = [1]; } let pads = attributes.pads; if (pads.length === 0) { pads = [0, 0]; } pads = [0, pads[0], 0, pads[1]]; strides = [1].concat(strides); dilations = [1].concat(dilations); kernelShape = [1].concat(kernelShape); const adjustedAttributes = getAdjustedConvTransposeAttributes({...attributes, pads, strides, dilations, kernelShape}, inputs); context.compute(createConvTranspose2DProgramInfo( inputs, adjustedAttributes, outputShape => isChannelLast ? [outputShape[0], outputShape[2], outputShape[3]] : [outputShape[0], outputShape[1], outputShape[3]])); }; export const convTranspose = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 3) { convTranspose1d(context, attributes); } else { convTranspose2d(context, context.inputs, attributes); } };