//
// The WebGPU API exposes several asynchronous methods,
// so they must be called from within an asynchronous function.
//
async setupGPU() {
    // Multiplies two 2x2 matrices on the GPU and logs the result.
    // Steps: acquire the device, upload both input matrices, compile the
    // GLSL compute shader to SPIR-V, dispatch it, then copy the result into
    // a mappable buffer and read it back.
    // Takes no parameters; all output goes to system.log.
    // NOTE(review): the WebGPU names used here (defaultQueue, dispatch,
    // endPass, computeStage, string "type" layout entries) look like an
    // early draft of the API - confirm against the target browser.
    if (!navigator.gpu) {
        system.log("WebGPU is not supported. Enable chrome://flags/#enable-unsafe-webgpu flag.")
        return()
    }
    
    // Get the GPU device.
    adapter ?= navigator.gpu.requestAdapter()
    // requestAdapter() may resolve to null when no suitable adapter exists;
    // stop here instead of failing on adapter.requestDevice().
    if (!adapter) {
        system.log("WebGPU adapter is not available.")
        return()
    }
    device ?= adapter.requestDevice()
    
    //
    // Sets CPU matrices data.
    //
    
    // Layout of each array: [rows, columns, data...]
    firstMatrix := Float32Array([
        2, 2,
        1, 2,
        3, 4])
    secondMatrix := Float32Array([
        2, 2,
        5, 6,
        7, 8])
    
    //
    // Sets GPU matrices data.
    //
    
    // Creates the input matrices in GPU memory. They are created already
    // mapped so their contents can be written before the first unmap().
    gpuBufferFirstMatrix = device.createBuffer({
        "mappedAtCreation": true,
        "size": firstMatrix.byteLength,
        "usage": GPUBufferUsage.STORAGE})
    gpuBufferSecondMatrix = device.createBuffer({
        "mappedAtCreation": true,
        "size": secondMatrix.byteLength,
        "usage": GPUBufferUsage.STORAGE})
    // Creates the matrix for the results: two header floats (rows, columns)
    // plus rows-of-first x columns-of-second data cells.
    resultMatrixBufferSize = Float32Array.BYTES_PER_ELEMENT * (2 + firstMatrix[0] * secondMatrix[1])
    resultMatrixBuffer = device.createBuffer({
        "size": resultMatrixBufferSize,
        "usage": GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC})
        
    // Copies the first CPU matrix into its mapped GPU buffer.
    arrayBufferFirstMatrix = gpuBufferFirstMatrix.getMappedRange()
    float32ArrayBufferFirstMatrix := Float32Array(arrayBufferFirstMatrix)
    float32ArrayBufferFirstMatrix.set(firstMatrix)
    gpuBufferFirstMatrix.unmap()
    
    // Copies the second CPU matrix into its mapped GPU buffer.
    arrayBufferSecondMatrix = gpuBufferSecondMatrix.getMappedRange()
    float32ArrayBufferSecondMatrix := Float32Array(arrayBufferSecondMatrix)
    float32ArrayBufferSecondMatrix.set(secondMatrix)
    gpuBufferSecondMatrix.unmap()
    
    //
    // Defines and compiles the compute shader.
    //
    
    // Bind group layout and bind group.

    // Defines the memory layout for the arguments received by the compute shader:
    // bindings 0 and 1 are the read-only inputs, binding 2 is the writable result.
    bindGroupLayout = device.createBindGroupLayout({
        "entries": [
            {"binding": 0, "visibility": GPUShaderStage.COMPUTE, "type": "readonly-storage-buffer"},
            {"binding": 1, "visibility": GPUShaderStage.COMPUTE, "type": "readonly-storage-buffer"},
            {"binding": 2, "visibility": GPUShaderStage.COMPUTE, "type": "storage-buffer"}]})
    bindGroup = device.createBindGroup({
        "layout": bindGroupLayout,
        "entries": [
            {"binding": 0, "resource": {"buffer": gpuBufferFirstMatrix}},
            {"binding": 1, "resource": {"buffer": gpuBufferSecondMatrix}},
            {"binding": 2, "resource": {"buffer": resultMatrixBuffer}}]})
  
    // Defines the compute shader. Each invocation computes one cell of the
    // result, selected by gl_GlobalInvocationID, and also writes the result
    // size header.
    source = "#version 450\n" +
    "layout(std430, set = 0, binding = 0) readonly buffer FirstMatrix {\n" +
    "    vec2 size;\n" +
    "    float numbers[];\n" +
    "} firstMatrix;\n" +
    "layout(std430, set = 0, binding = 1) readonly buffer SecondMatrix {\n" +
    "    vec2 size;\n" +
    "    float numbers[];\n" +
    "} secondMatrix;\n" +
    "layout(std430, set = 0, binding = 2) buffer ResultMatrix {\n" +
    "    vec2 size;\n" +
    "    float numbers[];\n" +
    "} resultMatrix;\n" +
    "void main() {\n" +
    "    resultMatrix.size = vec2(firstMatrix.size.x, secondMatrix.size.y);\n" +
    "    ivec2 resultCell = ivec2(gl_GlobalInvocationID.x, gl_GlobalInvocationID.y);\n" +
    "    float result = 0.0;\n" +
    "    for (int i = 0; i < firstMatrix.size.y; i++) {\n" +
    "        int a = i + resultCell.x * int(firstMatrix.size.y);\n" +
    "        int b = resultCell.y + i * int(secondMatrix.size.y);\n" +
    "        result += firstMatrix.numbers[a] * secondMatrix.numbers[b];\n" +
    "    }\n" +
    "    int index = resultCell.y + resultCell.x * int(secondMatrix.size.y);\n" +
    "    resultMatrix.numbers[index] = result;\n" +
    "}"

    // Imports the GLSL to SPIR-V compiler. This is the first point at which
    // glslang is available, so nothing above may reference it.
    glslangModule ?= import("https://unpkg.com/@webgpu/glslang@0.0.15/dist/web-devel/glslang.js")
    glslang ?= glslangModule.default()
    // Compiles the GLSL code as a compute shader.
    spirvCode = glslang.compileGLSL(source, "compute")
    system.log("GLSL code: " + JSON.stringify(source))
    system.log("SPIR-V bytecodes: " + JSON.stringify(spirvCode))
    
    // Pipeline setup.
    computePipeline = device.createComputePipeline({
        "layout": device.createPipelineLayout({"bindGroupLayouts": [bindGroupLayout]}),
        "computeStage": {"module": device.createShaderModule({"code": spirvCode}), "entryPoint": "main"}})
    
    // Creates an encoder to send commands to the GPU.
    commandEncoder = device.createCommandEncoder()
    passEncoder = commandEncoder.beginComputePass()
    passEncoder.setPipeline(computePipeline)
    passEncoder.setBindGroup(0, bindGroup)
    // Defines the dimensions of the processing unit grid:
    // rows of the first matrix by columns of the second matrix.
    passEncoder.dispatch(firstMatrix[0], secondMatrix[1])
    passEncoder.endPass()
    
    // Creates a separate buffer for reading the result back: the result
    // buffer is created without MAP_READ, so its contents are copied here.
    gpuReadBuffer = device.createBuffer({
        "size": resultMatrixBufferSize,
        "usage": GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ})
        
    // Encode commands for copying buffer to buffer.
    commandEncoder.copyBufferToBuffer(resultMatrixBuffer, 0, gpuReadBuffer, 0, resultMatrixBufferSize)
    
    // Submit GPU commands.
    gpuCommands = commandEncoder.finish()
    device.defaultQueue.submit([gpuCommands])

    // Read buffer: wait for the mapping to complete before accessing the data.
    bufferRead ?= gpuReadBuffer.mapAsync(GPUMapMode.READ)
    arrayBuffer = gpuReadBuffer.getMappedRange()
    float32ArrayBuffer := Float32Array(arrayBuffer)
    system.log("Matrices multiplication: " + float32ArrayBuffer)
}

// Entry point: starts the GPU setup, which sends the kernel and the computing
// data to the GPU asynchronously. The call is not awaited here, so any
// statements added after it would run before the GPU work has finished.
setupGPU()
