// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';

import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';

/**
 * Attributes of the Split kernel.
 */
export interface SplitAttributes extends AttributeWithCacheKey {
  /** Split axis as given by the model; normalized against the input rank via `ShapeUtil.normalizeAxis`. */
  readonly axis: number;
  /** Number of output tensors produced. */
  readonly numOutputs: number;
  /** Extent of output `i` along the split axis, for every `i < numOutputs`. */
  readonly splitSizes: number[];
}

/**
 * Ensures the kernel received at least the data input.
 *
 * @throws Error when no inputs are supplied.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
};

/**
 * Builds the effective attributes when split sizes may also arrive as the second (int64) input.
 *
 * A non-empty `inputs[1]` overrides the attribute-provided sizes and determines `numOutputs`.
 * An empty `inputs[1]` keeps the attribute-provided `splitSizes` and `numOutputs`.
 */
const createSplitAttributesFromInputs =
    (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
      if (inputs[1].dims[0] > 0) {
        const splitSizes = Array.from(inputs[1].getBigInt64Array(), v => Number(v));
        return createAttributeWithCacheKey({numOutputs: splitSizes.length, axis: attributes.axis, splitSizes});
      }
      // Keep the attribute sizes instead of an empty list: createSplitProgramInfo reads
      // splitSizes[i] for every output, and an empty list would yield NaN output shapes.
      return createAttributeWithCacheKey(
          {numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes: attributes.splitSizes});
    };

/**
 * Emits the WGSL helper `calculateOutputIndex(index)`.
 *
 * The helper returns the first output `i` whose cumulative upper bound
 * `uniforms.size_in_split_axis[i]` exceeds `index`. If no bound does, it returns `numberOfTensors`.
 */
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
fn calculateOutputIndex(index: u32) -> u32 {
    for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {
    if (index < ${getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors)}) {
        return i;
    }
    }
    return ${numberOfTensors}u;
}`;

/**
 * Emits the WGSL helper `writeBufferData(output_number, indices, global_idx)`.
 *
 * The helper stores `input[global_idx]` at `indices` in the output selected by `output_number`.
 * With several outputs, it is an if / else-if / else chain; the last output is the `else` branch.
 */
const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
  const numberOfTensors = outputs.length;
  const codeLines: string[] = [];
  for (let i = 0; i < numberOfTensors; ++i) {
    const returnSnippet = outputs[i].setByIndices('indices', 'input[global_idx]');
    if (numberOfTensors === 1) {
      codeLines.push(returnSnippet);
    } else if (i === 0) {
      codeLines.push(`if (output_number == ${i}u) { ${returnSnippet} }`);
    } else if (i === numberOfTensors - 1) {
      codeLines.push(`else { ${returnSnippet} }`);
    } else {
      // Use the same `u` suffix as the first branch so every comparison is explicitly u32.
      codeLines.push(`else if (output_number == ${i}u) { ${returnSnippet} }`);
    }
  }
  return `
      fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) {
        ${codeLines.join('\n')}
      }`;
};

/**
 * Builds the WebGPU program for Split.
 *
 * One invocation handles each input element. It locates the element's coordinate along the
 * split axis and picks the output that owns it. It then rebases the coordinate into that output
 * and writes the element there.
 */
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const inputSize = ShapeUtil.size(inputShape);
  const dataType = inputs[0].dataType;
  // Normalized split axis; every shape access below must use this rather than attributes.axis.
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
  const outputs = new Array<IndicesHelper>(attributes.numOutputs);
  const input = inputVariable('input', dataType, inputShape.length);
  // sizeInSplitAxis[i] = splitSizes[0] + ... + splitSizes[i], the exclusive upper bound of output i.
  const sizeInSplitAxis = new Array<number>(attributes.numOutputs);
  const outputsTensorInfo: TensorInfo[] = [];
  const outputShapes: number[][] = [];
  let previousSum = 0;
  const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}];
  for (let i = 0; i < attributes.numOutputs; i++) {
    previousSum += attributes.splitSizes[i];
    sizeInSplitAxis[i] = previousSum;
    const outputShape = inputShape.slice();
    // Indexing with a negative attributes.axis would set a non-index property on the array and
    // leave the real dimension untouched, so the normalized axis is required here.
    outputShape[axis] = attributes.splitSizes[i];
    outputShapes.push(outputShape);
    outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
    outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
  }
  programUniforms.push(
      {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes));
  const getShaderSource = (shaderHelper: ShaderHelper) => `
  ${
      shaderHelper.registerUniform('input_size', 'u32')
          .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length)
          .declareVariables(input, ...outputs)}
  ${calculateOutputIndexImpl(sizeInSplitAxis.length)}
  ${writeBufferDataImpl(outputs)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')}

    var indices = ${input.offsetToIndices('global_idx')};
    var index = ${input.indicesGet('indices', axis)};
    let output_number = calculateOutputIndex(index);
    if (output_number != 0) {
      index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)};
      ${input.indicesSet('indices', axis, 'index')};
    }
    writeBufferData(output_number, indices, global_idx);
  }`;
  return {
    name: 'Split',
    shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']},
    getShaderSource,
    getRunData: () => ({
      outputs: outputsTensorInfo,
      dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)},
      programUniforms
    })
  };
};

/**
 * Split kernel entry point.
 *
 * If a second input is present, it may supply the split sizes (see createSplitAttributesFromInputs).
 * Only input 0 (the data tensor) is bound to the GPU program.
 */
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
  validateInputs(context.inputs);
  const updatedAttributes =
      context.inputs.length === 1 ? attributes : createSplitAttributesFromInputs(context.inputs, attributes);
  context.compute(createSplitProgramInfo(context.inputs, updatedAttributes), {inputs: [0]});
};

/**
 * Parses raw Split attributes.
 *
 * A negative `numOutputs` means "derive it from `splitSizes.length`".
 *
 * @throws Error when `numOutputs` differs from `splitSizes.length`.
 */
export const parseSplitAttributes = (attributes: Record<string, unknown>): SplitAttributes => {
  const axis = attributes.axis as number;
  const splitSizes: number[] = attributes.splitSizes as number[];
  const requestedOutputs = attributes.numOutputs as number;
  const numOutputs = requestedOutputs < 0 ? splitSizes.length : requestedOutputs;
  if (numOutputs !== splitSizes.length) {
    throw new Error('numOutputs and splitSizes length must be equal');
  }
  return createAttributeWithCacheKey({axis, numOutputs, splitSizes});
};