// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
import {attention} from './ops/attention';
import {batchNorm} from './ops/batch-norm';
import {biasAdd} from './ops/bias-add';
import {biasSplitGelu} from './ops/bias-split-gelu';
import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
import {depthToSpace, parseDepthToSpaceAttributes} from './ops/depth-to-space';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
import {fastGelu} from './ops/fast-gelu';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention';
import {instanceNorm} from './ops/instance-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multihead-attention';
import {pad} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {rotaryEmbedding} from './ops/rotary-embedding';
import {skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {where} from './ops/where';
import {ComputeContext} from './types';

/**
 * Executes one operator against a compute context.
 *
 * `attribute` is optional. NOTE(review): presumably it receives the value
 * produced by the operator's {@link ParseAttributeFunction}, when one is
 * registered. The dispatching caller is not visible here, so confirm there.
 */
export type RunFunction = (context: ComputeContext, attribute?: unknown) => void;

/**
 * Converts an operator's raw attribute value into the form its
 * {@link RunFunction} consumes. Input and output are both `unknown` at this
 * layer; each operator module defines the concrete shape.
 */
export type ParseAttributeFunction = (attributeRaw: unknown) => unknown;

/**
 * A registry entry. It is either just a run function, or a run function
 * paired with an attribute parser for operators that need attribute
 * pre-processing.
 */
export type OperatorImplementation = [RunFunction]|[RunFunction, ParseAttributeFunction];

/**
 * Maps an operator type name (for example `'Conv'` or `'MatMul'`) to its
 * WebGPU implementation.
 *
 * The explicit `Map<string, OperatorImplementation>` type arguments are
 * required. A bare `Map` annotation is a compile error (TS2314). The
 * arguments also give the entry literals below their contextual type, so
 * `[fn]` / `[fn, parse]` are checked as `OperatorImplementation` tuples
 * instead of being widened to plain arrays.
 *
 * Entry order is insertion order, which is also the Map's iteration order.
 * Several entries share implementations, for example `FusedConv` reuses
 * `conv` / `parseConvAttributes`, and `Elu`, `LeakyRelu`, `QuickGelu` and
 * `ThresholdedRelu` all use `unaryOps.parseAlphaAttributes`.
 */
export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new Map([
  ['Abs', [unaryOps.abs]],
  ['Acos', [unaryOps.acos]],
  ['Acosh', [unaryOps.acosh]],
  ['Add', [binaryOps.add]],
  ['ArgMax', [argMax, parseArgMinMaxAttributes]],
  ['ArgMin', [argMin, parseArgMinMaxAttributes]],
  ['Asin', [unaryOps.asin]],
  ['Asinh', [unaryOps.asinh]],
  ['Atan', [unaryOps.atan]],
  ['Atanh', [unaryOps.atanh]],
  ['Attention', [attention]],
  // TODO: support new attributes for AveragePool-10
  ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
  ['BatchNormalization', [batchNorm]],
  ['BiasAdd', [biasAdd]],
  ['BiasSplitGelu', [biasSplitGelu]],
  ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
  ['Ceil', [unaryOps.ceil]],
  ['Clip', [unaryOps.clip]],
  ['Concat', [concat, parseConcatAttributes]],
  ['Conv', [conv, parseConvAttributes]],
  ['ConvTranspose', [convTranspose, parseConvTransposeAttributes]],
  ['Cos', [unaryOps.cos]],
  ['Cosh', [unaryOps.cosh]],
  ['CumSum', [cumsum, parseCumSumAttributes]],
  ['DepthToSpace', [depthToSpace, parseDepthToSpaceAttributes]],
  ['Div', [binaryOps.div]],
  ['Einsum', [einsum, parseEinsumAttributes]],
  ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
  ['Equal', [binaryOps.equal]],
  ['Erf', [unaryOps.erf]],
  ['Exp', [unaryOps.exp]],
  ['Expand', [expand]],
  ['FastGelu', [fastGelu]],
  ['Floor', [unaryOps.floor]],
  ['FusedConv', [conv, parseConvAttributes]],
  ['Gather', [gather, parseGatherAttributes]],
  ['GatherElements', [gatherElements, parseGatherElementsAttributes]],
  ['Gelu', [unaryOps.gelu]],
  ['Gemm', [gemm, parseGemmAttributes]],
  ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
  ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
  ['Greater', [binaryOps.greater]],
  ['GreaterOrEqual', [binaryOps.greaterOrEqual]],
  ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
  ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
  ['InstanceNormalization', [instanceNorm]],
  ['LayerNormalization', [layerNorm]],
  ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
  ['Less', [binaryOps.less]],
  ['LessOrEqual', [binaryOps.lessOrEqual]],
  ['Log', [unaryOps.log]],
  ['MatMul', [matMul]],
  ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]],
  // TODO: support new attributes for MaxPool-8 and MaxPool-10
  ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
  ['Mul', [binaryOps.mul]],
  ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
  ['Neg', [unaryOps.neg]],
  ['Not', [unaryOps.not]],
  ['Pad', [pad]],
  ['Pow', [binaryOps.pow]],
  ['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]],
  ['Range', [range]],
  ['Reciprocal', [unaryOps.reciprocal]],
  ['ReduceMin', [reduceMin]],
  ['ReduceMean', [reduceMean]],
  ['ReduceMax', [reduceMax]],
  ['ReduceSum', [reduceSum]],
  ['ReduceProd', [reduceProd]],
  ['ReduceL1', [reduceL1]],
  ['ReduceL2', [reduceL2]],
  ['ReduceLogSum', [reduceLogSum]],
  ['ReduceLogSumExp', [reduceLogSumExp]],
  ['ReduceSumSquare', [reduceSumSquare]],
  ['Relu', [unaryOps.relu]],
  ['Resize', [resize, parseResizeAttributes]],
  ['RotaryEmbedding', [rotaryEmbedding]],
  ['Sigmoid', [unaryOps.sigmoid]],
  ['Sin', [unaryOps.sin]],
  ['Sinh', [unaryOps.sinh]],
  ['Slice', [slice, parseSliceAttributes]],
  ['SkipLayerNormalization', [skipLayerNorm]],
  ['Split', [split, parseSplitAttributes]],
  ['Sqrt', [unaryOps.sqrt]],
  ['Softmax', [softmax, parseSoftmaxAttributes]],
  ['Sub', [binaryOps.sub]],
  ['Tan', [unaryOps.tan]],
  ['Tanh', [unaryOps.tanh]],
  ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
  ['Tile', [tile]],
  ['Transpose', [transpose, parseTransposeAttributes]],
  ['Where', [where]],
]);