Implementing Matrix Multiplication with WebGPU in Safari

This is a quick overview of how to write a matrix multiplication for Safari leveraging the WebGPU API. This will run on both Macs and iPhones provided WebGPU is enabled.

The benchmarks in this document are done on an M1 chip.

The full code can be found here and a demo here.

Code

First, we'll need an async function to set up our GPU. This is because the API is largely asynchronous and many core methods return promises we'll have to await.

async function run() {
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();
  
  const M = 1024;
  const N = 1024;
  const K = 1024;
  
  const mm = createMatrixMultiplication(device, M, N, K);  
}

window.addEventListener('load', run);

We set up the adapter and a device. Then we stub out a call to createMatrixMultiplication, which will return the mm function we use to multiply matrices.
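Since WebGPU has to be enabled for any of this to work, it can help to check for navigator.gpu before requesting an adapter. A minimal sketch, assuming a hypothetical helper named hasWebGPU that is not part of the original code:

```javascript
// Hypothetical helper (not in the original code): reports whether a
// navigator-like object exposes a WebGPU entry point.
function hasWebGPU(nav) {
  return nav !== null && typeof nav === 'object' && nav.gpu != null;
}

// Example use in the browser, before calling run():
//   if (!hasWebGPU(navigator)) { /* show a fallback message instead */ }
```

Taking the navigator as a parameter keeps the check easy to exercise outside a browser.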

Run once code (where perf isn't important)

WebGPU requires some setup that we will want to run only once. The boilerplate includes a bind group layout, a pipeline layout, and the compute pipeline itself:

function createMatrixMultiplication(device, M, N, K) {

  // BindGroupLayout

  const visibility = GPUShaderStage.COMPUTE;
  const type = "storage-buffer";

  const bindGroupLayout = device.createBindGroupLayout({
    bindings: [
      { binding: 0, visibility: visibility, type: type },
      { binding: 1, visibility: visibility, type: type },
      { binding: 2, visibility: visibility, type: type },
    ]
  });

  // PipelineLayout

  const pipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [bindGroupLayout],
  });

  // ComputePipeline

  const source = generateMatrixMultiplicationKernel(M, N, K);

  const computePipeline = device.createComputePipeline({
    layout: pipelineLayout,
    computeStage: {
      module: device.createShaderModule({
        code: source,
      }),
      entryPoint: "main"
    }
  });

  // define the mm function

  function mm(A, B, C) {
     // see below
  }

  return mm;
}

With the "run once" boilerplate out of the way we've exposed two more functions to implement.

Run every time code (where perf matters)

The first is a bit more boilerplate: the mm function.

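What we'll do in this function is the work that is required for every invocation. This includes creating a command encoder, binding the three buffers, and encoding a compute pass that dispatches the kernel.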

Instead of immediately invoking the generated command buffer, we will return it. This will free up our ability to chain together command buffers and potentially hide dispatch latency.

  function mm(A, B, C) {
    const commandEncoder = device.createCommandEncoder();
    const bindGroup = device.createBindGroup({
      layout: bindGroupLayout,
      bindings: [
        { binding: 0, resource: { buffer: A, size: M * K * 4 } },
        { binding: 1, resource: { buffer: B, size: N * K * 4 } },
        { binding: 2, resource: { buffer: C, size: M * N * 4 } },
      ]
    });

    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(computePipeline);
    passEncoder.setBindGroup(0, bindGroup);
    passEncoder.dispatch(M / 8, N / 8, 1);
    passEncoder.endPass();
    return commandEncoder.finish();
  }

Note that we multiply the number of elements by 4 because the buffer size is in bytes and we are using floats, which are 4 bytes each.
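The byte arithmetic can be kept in one place rather than repeated as a literal 4. A small sketch, where matrixByteSize is a hypothetical helper name and not part of the original code:

```javascript
// Hypothetical helper: byte size of a rows x cols matrix of 32-bit floats.
function matrixByteSize(rows, cols) {
  // Float32Array.BYTES_PER_ELEMENT is 4, the size of one float.
  return rows * cols * Float32Array.BYTES_PER_ELEMENT;
}

// For the 1024 x 1024 matrices used here, each buffer is 4 MiB.
const exampleBytes = matrixByteSize(1024, 1024);
```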

The second function we'll need to implement is more interesting: generateMatrixMultiplicationKernel. For the sake of simplicity we will hardcode the input sizes. These can always be passed as arguments to the kernel (that'd require more binding code).

The kernel

If you're not familiar with matrix multiplication, the operation is defined as

$$(AB)_{mn} = \sum_{k} a_{mk} b_{kn}$$

where k is the reduction axis.

Computationally, if k is small, we end up with a memory-bound operation. If k and at least one of m or n are large, we can readily leverage the compute power of our GPU. However, if k is large and m and n are small, we will need to employ some tricks to parallelize the reduction itself. For the sake of this demo, we'll focus on the first two cases.
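One rough way to see the memory-bound case is to compare floating point operations against bytes moved. The sketch below assumes an idealized model in which every element of A, B and C is touched exactly once and caches are ignored. arithmeticIntensity is a hypothetical helper, not something from the original code:

```javascript
// Hypothetical helper: flops per byte for an M x K by K x N multiplication,
// assuming each float of A, B and C is read or written exactly once.
function arithmeticIntensity(M, N, K) {
  const flops = 2 * M * N * K;               // one multiply and one add per term
  const bytes = 4 * (M * K + K * N + M * N); // 4 bytes per float
  return flops / bytes;
}

// K = 1 gives under half a flop per byte (memory bound), while
// M = N = K = 1024 gives roughly 170 flops per byte.
const smallK = arithmeticIntensity(1024, 1024, 1);
const largeK = arithmeticIntensity(1024, 1024, 1024);
```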

Before writing the kernel, we should implement a CPU version for correctness checks.

function mm_ref(A, B, C, M, N, K) {
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      let res = 0;
      for (let k = 0; k < K; ++k) {
        res += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = res;
    }
  }
}
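To use the reference for correctness checks we also need a way to compare two results. A minimal sketch with a hypothetical maxAbsDiff helper. The reference is repeated so the snippet runs on its own, and the acceptable tolerance for GPU results is left to the caller:

```javascript
// Copy of the CPU reference from the text, so this block is self-contained.
function mm_ref(A, B, C, M, N, K) {
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      let res = 0;
      for (let k = 0; k < K; ++k) {
        res += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = res;
    }
  }
}

// Hypothetical helper: largest absolute elementwise difference.
function maxAbsDiff(x, y) {
  if (x.length !== y.length) throw new Error('length mismatch');
  let worst = 0;
  for (let i = 0; i < x.length; ++i) {
    worst = Math.max(worst, Math.abs(x[i] - y[i]));
  }
  return worst;
}

// 2 x 2 example: [[1, 2], [3, 4]] times [[5, 6], [7, 8]].
const exA = new Float32Array([1, 2, 3, 4]);
const exB = new Float32Array([5, 6, 7, 8]);
const exC = new Float32Array(4);
mm_ref(exA, exB, exC, 2, 2, 2);
// exC now holds [19, 22, 43, 50]
```

With a GPU result copied back to the CPU, the same helper could be applied to it and the reference output.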

And then a working kernel. Even totally unoptimized, it is still reasonably fast compared to the CPU code.

function generateMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
        compute void main(device float[] A : register(u0),
                          device float[] B : register(u1),
                          device float[] C : register(u2),
                          float3 threadID : SV_DispatchThreadID) {
            uint m = uint(threadID.x);
            uint n = uint(threadID.y);
            uint c_idx = n + m * ${N};

            float result = 0.0;
            for (uint k = 0; k < ${K}; k++) {
                uint a_idx = k + m * ${K};
                uint b_idx = n + k * ${N};
                result += A[a_idx] * B[b_idx];
            }
            C[c_idx] = result;
        }`
}

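There's a fair amount to observe about the source above. The numthreads(8, 8, 1) attribute sets the thread group size, which is why mm dispatches M / 8 by N / 8 groups. The register(u0) through register(u2) annotations line up with bindings 0 through 2 of the bind group. The matrix sizes are baked into the source as constants by the template literal, and each thread computes a single element of C.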

Putting it together

async function run() {
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();

  const M = 1024;
  const N = 1024;
  const K = 1024;

  const mm = createMatrixMultiplication(device, M, N, K);  
  
  const A = randGPU(device, M * K);
  const B = randGPU(device, K * N);
  const C = randGPU(device, M * N);
  
  const t0 = performance.now();
  
  device.getQueue().submit([mm(A, B, C)]);
  const result = await toCPU(device, C, M * N);
  
  const t1 = performance.now();
  
  // log the gflops achieved
  const flops = M * N * K * 2;
  console.log(flops / ((t1 - t0) * 1e6));
}

The goal of the above code is to benchmark the performance of our matrix multiplication implementation. flops is the total number of floating point operations: K multiplications and K additions for each of the M * N elements in the output. performance.now() returns a time in milliseconds, so to get gigaflops (billions of floating point operations per second) we divide flops by the elapsed time multiplied by 1e6.
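The unit conversion is easy to get wrong, so here it is as a standalone function. This is a sketch: gflops is a hypothetical helper, and the reps parameter is an assumption added for runs that time several multiplications together:

```javascript
// Hypothetical helper: gigaflops achieved by `reps` multiplications of an
// M x K matrix by a K x N matrix that took `elapsedMs` milliseconds in total.
function gflops(M, N, K, elapsedMs, reps = 1) {
  const flops = 2 * M * N * K * reps;
  // flops / (ms / 1e3) is flops per second; dividing by 1e9 gives gigaflops,
  // which simplifies to flops / (ms * 1e6).
  return flops / (elapsedMs * 1e6);
}

// A 1024 cube multiplication that takes 10 ms runs at about 214.7 gigaflops.
const exampleRate = gflops(1024, 1024, 1024, 10);
```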

On my machine I see around 160Gflops. This is quite bad compared to the M1 GPU specs which claim it can hit nearly 2.3Tflops. We've got a ways to go!

Before optimizing, let's go over the two functions randGPU and toCPU that allocate and copy memory respectively.

Memory Management

function randGPU(device, numel) {
  const [gpu, cpu] = device.createBufferMapped({
    size: numel * 4, // sizeof float
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });
  let rand = new Float32Array(numel);
  for (let i = 0; i < numel; ++i) {
    rand[i] = Math.random() / 500;
  }
  new Float32Array(cpu).set(rand);
  gpu.unmap();
  return gpu;
}

We create a mapped buffer and populate it with small random values. Calling gpu.unmap() ends JavaScript's access to the buffer's contents and makes the buffer usable on the GPU itself.

Unlike allocating data for the GPU, getting data from the GPU to the CPU is asynchronous.

async function toCPU(device, gpu_array, numel) {
  const buffer = device.createBuffer({
    size: numel * 4,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  const commandEncoder = device.createCommandEncoder();
  commandEncoder.copyBufferToBuffer(gpu_array, 0, buffer, 0, numel * 4);
  device.getQueue().submit([commandEncoder.finish()]);

  return new Float32Array(await buffer.mapReadAsync());
}

We first create a read buffer that can be mapped to CPU and encode the command required to copy another GPU buffer (the output of mm).

After submitting the buffer-to-buffer copy, we await the mapping operation and finally wrap the mapped data in a new Float32Array.

Optimizing the kernel

Before we can optimize, we should improve the benchmark a bit. Instead of a single invocation, we should leverage the pipelining ability of asynchronous dispatch.

  const t0 = performance.now();
  
  device.getQueue().submit([
    mm(A, B, C),
    mm(C, B, A),
    mm(A, C, B),
    mm(B, A, C),
    mm(A, B, C),
    mm(C, B, A),
    mm(A, C, B),
    mm(B, A, C),
  ]);
  const result = await toCPU(device, C, M * N);
  
  const t1 = performance.now();
  
  const flops = M * N * K * 2 * 8;
  console.log(flops / ((t1 - t0) * 1e6));

We now do 8 back to back matrix multiplications, so we have to update the flops calculation accordingly.

Vectorized fused multiply-accumulate

The first step to optimization is to ensure the hottest instructions are also the most efficient. Right now our kernel uses scalar multiplication and scalar addition. We should instead use a fused variant, often referred to as an FMA.

WebGPU exposes an approximation of an FMA instruction, mad. Since this single instruction does a * b + c, it might speed up our computation. (Empirically, I didn't measure any noticeable speedup from swapping to mad over a * b + c, which may be the compiler helping us out).

Another potential issue with performance is that our kernel uses scalars instead of vectors. There are many ways to vectorize our code, but for now we can vectorize the contiguous output dimension N. Since matrix A doesn't have dimension N, we can continue to read it as a scalar float[] array. For B and output C, we will want to swap to float4[] arrays.

function generateVectorizedMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
        compute void main(device float[] A : register(u0), 
                          device float4[] B : register(u1),
                          device float4[] C : register(u2),
                          float3 threadID : SV_DispatchThreadID) {
            uint m = uint(threadID.x);
            uint n = uint(threadID.y);
            uint c_idx = n + m * ${N/4};

            float4 result = float4(0.0, 0.0, 0.0, 0.0);
            for (uint k = 0; k < ${K}; k++) {
                uint a_idx = k + m * ${K};
                float a_elem = A[a_idx];
                float4 a = float4(a_elem, a_elem, a_elem, a_elem);
                uint b_idx = n + k * ${N/4};
                float4 b = B[b_idx];
                result  = mad(a, b, result);
            }
            C[c_idx] = result;
        }`
}
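The float4 indexing is easy to get subtly wrong, so it can be worth emulating the kernel on the CPU and comparing against the scalar reference. The sketch below mirrors the kernel's index arithmetic one thread at a time. mmRef and mmVec4Emulated are hypothetical names, and the emulation accumulates in double precision rather than float32:

```javascript
// Compact row-major CPU reference (same arithmetic as mm_ref in the text).
function mmRef(A, B, M, N, K) {
  const C = new Float32Array(M * N);
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      let res = 0;
      for (let k = 0; k < K; ++k) res += A[m * K + k] * B[k * N + n];
      C[m * N + n] = res;
    }
  }
  return C;
}

// Emulates every thread (m, n) of the vectorized kernel. B and C are plain
// float arrays here, so float4 index i covers floats 4 * i .. 4 * i + 3.
function mmVec4Emulated(A, B, M, N, K) {
  const C = new Float32Array(M * N);
  const n4 = N / 4; // number of float4 vectors per row (N must divide by 4)
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < n4; ++n) {
      const result = [0, 0, 0, 0];
      for (let k = 0; k < K; ++k) {
        const a = A[k + m * K];  // scalar read, broadcast to all lanes
        const bIdx = n + k * n4; // float4 index into B
        for (let j = 0; j < 4; ++j) result[j] += a * B[4 * bIdx + j];
      }
      const cIdx = n + m * n4;   // float4 index into C
      for (let j = 0; j < 4; ++j) C[4 * cIdx + j] = result[j];
    }
  }
  return C;
}
```

For row-major inputs whose N is a multiple of 4, both functions should produce the same output.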

We will also need to update the dispatch logic to account for the fact that we are reading and writing 4 elements at a time.

function createMatrixMultiplication(device, M, N, K) {
  // ...
  function mm(A, B, C) {
    // ...
    passEncoder.dispatch(M / 8, N / 8 / 4, 1);
    // ...
  }
  
  return mm;
}
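The kernels in this document do no bounds checking, so the dispatch counts only make sense when the sizes divide evenly. A small sketch of a helper that computes the counts and rejects sizes that do not fit. dispatchSize and its parameters are hypothetical, not part of the original code:

```javascript
// Hypothetical helper: thread group counts for an M x N output, given the
// group edge length (8 for numthreads(8, 8, 1)) and how many output rows
// and columns each thread produces.
function dispatchSize(M, N, group = 8, rowsPerThread = 1, colsPerThread = 1) {
  const x = M / (group * rowsPerThread);
  const y = N / (group * colsPerThread);
  if (!Number.isInteger(x) || !Number.isInteger(y)) {
    throw new Error(`sizes ${M} x ${N} do not divide evenly`);
  }
  return [x, y, 1];
}

// Scalar kernel: one output per thread.
const scalarGroups = dispatchSize(1024, 1024);
// Vectorized kernel: four output columns per thread.
const vectorGroups = dispatchSize(1024, 1024, 8, 1, 4);
```

The result could be spread into the dispatch call, for example passEncoder.dispatch(...vectorGroups).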

We've improved the performance a lot with this small change, going from 160Gflops to over 420Gflops.

But we can do better!

Tiling

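Instead of a single vector, we can use multiple vectors. Each thread now accumulates four float4 vectors, a 4x4 tile of the output, and A is read as float4[] as well, with the four lanes of each A vector feeding four different output rows. Since each thread now covers four rows, the dispatch along M should shrink by a factor of 4 too (M / 8 / 4).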

function generateTiledMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
        compute void main(device float4[] A : register(u0),
                          device float4[] B : register(u1),
                          device float4[] C : register(u2),
                          float3 threadID : SV_DispatchThreadID) {
            uint m = uint(threadID.x);
            uint n = uint(threadID.y);
            uint c_idx = (n + m * ${N/4}) * 4;

            float4 result0 = float4(0.0, 0.0, 0.0, 0.0);
            float4 result1 = float4(0.0, 0.0, 0.0, 0.0);
            float4 result2 = float4(0.0, 0.0, 0.0, 0.0);
            float4 result3 = float4(0.0, 0.0, 0.0, 0.0);
            for (uint k = 0; k < ${K}; k++) {
                uint a_idx = k + m * ${K};
                float4 a = A[a_idx];
                float4 a_x = float4(a.x, a.x, a.x, a.x);
                float4 a_y = float4(a.y, a.y, a.y, a.y);
                float4 a_z = float4(a.z, a.z, a.z, a.z);
                float4 a_w = float4(a.w, a.w, a.w, a.w);
                uint b_idx = n + k * ${N/4};
                float4 b = B[b_idx];
                result0 = mad(a_x, b, result0);
                result1 = mad(a_y, b, result1);
                result2 = mad(a_z, b, result2);
                result3 = mad(a_w, b, result3);
            }
            C[c_idx + 0] = result0;
            C[c_idx + 1] = result1;
            C[c_idx + 2] = result2;
            C[c_idx + 3] = result3;
        }`
}

That brings us to around 680Gflops!
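Reading the kernel's indexing closely, the four lanes of each A vector feed four different output rows, and the four result vectors are stored next to each other in C. With random inputs this does not affect the benchmark, but checking against a row-major reference seems to need a repacking step. The sketch below is one interpretation of that layout, inferred from the index arithmetic. packA, unpackC and mmTiledEmulated are hypothetical names, and accumulation is in double precision rather than float32:

```javascript
// Compact row-major CPU reference (same arithmetic as mm_ref in the text).
function mmRef(A, B, M, N, K) {
  const C = new Float32Array(M * N);
  for (let m = 0; m < M; ++m)
    for (let n = 0; n < N; ++n) {
      let res = 0;
      for (let k = 0; k < K; ++k) res += A[m * K + k] * B[k * N + n];
      C[m * N + n] = res;
    }
  return C;
}

// Assumed layout: lane i of float4 (k + mt * K) holds A[4 * mt + i][k].
function packA(A, M, K) {
  const out = new Float32Array(M * K);
  for (let mt = 0; mt < M / 4; ++mt)
    for (let k = 0; k < K; ++k)
      for (let i = 0; i < 4; ++i)
        out[4 * (k + mt * K) + i] = A[(4 * mt + i) * K + k];
  return out;
}

// Emulates every thread (mt, nt) of the tiled kernel on the CPU.
function mmTiledEmulated(Ap, B, M, N, K) {
  const Cp = new Float32Array(M * N);
  const n4 = N / 4;
  for (let mt = 0; mt < M / 4; ++mt)
    for (let nt = 0; nt < n4; ++nt) {
      const acc = new Float64Array(16); // result0..result3, four lanes each
      for (let k = 0; k < K; ++k)
        for (let i = 0; i < 4; ++i) {
          const a = Ap[4 * (k + mt * K) + i]; // a.x, a.y, a.z or a.w
          for (let j = 0; j < 4; ++j)
            acc[4 * i + j] += a * B[4 * (nt + k * n4) + j];
        }
      const cIdx = (nt + mt * n4) * 4; // first of four float4 stores
      for (let i = 0; i < 4; ++i)
        for (let j = 0; j < 4; ++j) Cp[4 * (cIdx + i) + j] = acc[4 * i + j];
    }
  return Cp;
}

// Converts the kernel's tiled output back to row-major order.
function unpackC(Cp, M, N) {
  const C = new Float32Array(M * N);
  const n4 = N / 4;
  for (let mt = 0; mt < M / 4; ++mt)
    for (let nt = 0; nt < n4; ++nt)
      for (let i = 0; i < 4; ++i)
        for (let j = 0; j < 4; ++j)
          C[(4 * mt + i) * N + 4 * nt + j] = Cp[4 * ((nt + mt * n4) * 4 + i) + j];
  return C;
}
```

Under these assumptions, unpackC(mmTiledEmulated(packA(A, M, K), B, M, N, K), M, N) should match mmRef(A, B, M, N, K) when M and N are multiples of 4.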

More to go

There's more we can do, such as increasing the tile size and unrolling the reduction loop. I'll leave these optimizations as an exercise to the reader.

These types of improvements might also be well handled by generating the source instead of hand-tuning it. In fact, there exist compilers, like TVM, that can tune all of this stuff for us and have been hitting performance much closer to native code.
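As a very small taste of generating source, the repeated accumulator lines of the tiled kernel can be produced from a list of lanes instead of being typed out. This is only a sketch of the idea: emitAccumulators is a hypothetical helper, and a real generator or a compiler like TVM does far more than this:

```javascript
// Hypothetical helper: emits the per-lane lines of the tiled kernel body.
// Each lane of the float4 `a` gets its own accumulator, fused multiply-add
// and store, using the same variable names as the hand-written kernel.
function emitAccumulators(lanes = ['x', 'y', 'z', 'w']) {
  const decl = lanes.map(
    (_, i) => `float4 result${i} = float4(0.0, 0.0, 0.0, 0.0);`);
  const fma = lanes.map(
    (l, i) => `result${i} = mad(float4(a.${l}, a.${l}, a.${l}, a.${l}), b, result${i});`);
  const store = lanes.map(
    (_, i) => `C[c_idx + ${i}] = result${i};`);
  return { decl, fma, store };
}

// The pieces can be joined into a kernel template string, for example:
const pieces = emitAccumulators();
const loopBody = pieces.fma.join('\n');
```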

It's great that the web is starting to adopt native GPU support and it's also really pleasant to be able to quickly hit good performance.