## Implementing Matrix Multiplication with WebGPU in Safari

This is a quick overview of how to write a matrix multiplication for Safari leveraging the WebGPU API. This will run on both Macs and iPhones provided [WebGPU is enabled](https://mil-tokyo.github.io/webdnn/docs/tips/enable_webgpu_ios.html). The benchmarks in this document are done on an M1 chip.

The full code can be found [here](https://jott.live/code/webgpu_mm.js) and a demo [here](https://jott.live/html/webgpu_demo.html).

## Code

First, we'll need an `async` function to set up our GPU. This is because the API is largely asynchronous and many core methods return promises we'll have to `await`.

```
async function run() {
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();

  const M = 1024;
  const N = 1024;
  const K = 1024;
  const mm = createMatrixMultiplication(device, M, N, K);
}

window.addEventListener('load', run);
```

We set up the [adapter](https://gpuweb.github.io/gpuweb/#adapters) and a [device](https://gpuweb.github.io/gpuweb/#devices). Then we scaffold out the invocation of a function that will create an `mm` function.

### Run once code (where perf isn't important)

WebGPU requires some setup that we will want to run only once. The boilerplate includes:

- creating a [BindGroupLayout](https://gpuweb.github.io/gpuweb/#gpubindgrouplayout) that will define how resources can be accessed in the GPU kernel.
- creating a [PipelineLayout](https://gpuweb.github.io/gpuweb/#pipeline-layout) that will define a mapping between the [BindGroups](https://gpuweb.github.io/gpuweb/#gpu-bind-group) and the render or compute shaders.
- creating a [ComputePipeline](https://gpuweb.github.io/gpuweb/#compute-pipeline) that will compile the kernel source.
```
function createMatrixMultiplication(device, M, N, K) {
  // BindGroupLayout: three storage buffers visible to the compute stage
  const visibility = GPUShaderStage.COMPUTE;
  const type = "storage-buffer";
  const bindGroupLayout = device.createBindGroupLayout({
    bindings: [
      { binding: 0, visibility: visibility, type: type },
      { binding: 1, visibility: visibility, type: type },
      { binding: 2, visibility: visibility, type: type },
    ]
  });

  // PipelineLayout
  const pipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [bindGroupLayout],
  });

  // ComputePipeline: compiles the generated kernel source
  const source = generateMatrixMultiplicationKernel(M, N, K);
  const computePipeline = device.createComputePipeline({
    layout: pipelineLayout,
    computeStage: {
      module: device.createShaderModule({
        code: source,
      }),
      entryPoint: "main"
    }
  });

  // define the mm function
  function mm(A, B, C) {
    // see below
  }

  return mm;
}
```

With the "run once" boilerplate out of the way, we've exposed two more functions to implement.

### Run every time code (where perf matters)

The first is a bit more boilerplate: the `mm` function. In this function we do the work that is required for every invocation. This includes:

- creating a [CommandEncoder](https://gpuweb.github.io/gpuweb/#command-encoder) to encapsulate our ComputePassEncoder.
- creating a [BindGroup](https://gpuweb.github.io/gpuweb/#gpu-bind-group) to bind our runtime memory to the BindGroupLayout we defined earlier.
- creating a [ComputePassEncoder](https://gpuweb.github.io/gpuweb/#compute-pass-encoder) to bind the compiled source and BindGroup as well as the dispatch parameters.

Instead of immediately submitting the generated command buffer, we will return it. This lets us chain together command buffers and potentially hide dispatch latency.
```
function mm(A, B, C) {
  const commandEncoder = device.createCommandEncoder();

  // Bind the runtime buffers to the layout; sizes are in bytes (4 per float32).
  const bindGroup = device.createBindGroup({
    layout: bindGroupLayout,
    bindings: [
      { binding: 0, resource: { buffer: A, size: M * K * 4 } },
      { binding: 1, resource: { buffer: B, size: N * K * 4 } },
      { binding: 2, resource: { buffer: C, size: M * N * 4 } },
    ]
  });

  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(computePipeline);
  passEncoder.setBindGroup(0, bindGroup);
  // One workgroup per 8x8 block of C, matching numthreads(8, 8, 1) in the kernel.
  passEncoder.dispatch(M / 8, N / 8, 1);
  passEncoder.endPass();

  // Return the command buffer instead of submitting it.
  return commandEncoder.finish();
}
```

Note that we multiply the number of elements by `4` because the buffer size is in bytes and we are using 32-bit floats, which are 4 bytes each.

The second function we'll need to implement is more interesting: `generateMatrixMultiplicationKernel`. For simplicity, we will hardcode the input sizes into the kernel source. They could instead be passed to the kernel as arguments, but that would require more binding code.

### The kernel

If you're not familiar with matrix multiplication, each element of the output $\textbf{C} = \textbf{A} \textbf{B}$ is defined as

$$ c_{mn} = \sum_{k} a_{mk} b_{kn} $$

where $k$ indexes the reduction axis, of size $K$.

<center>
<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/e/eb/Matrix_multiplication_diagram_2.svg/1920px-Matrix_multiplication_diagram_2.svg.png" width="400" max-width="80%"/>
</center>

Computationally, if $K$ is small, we end up with a memory bound operation. If $K$ and $M$ or $N$ are large, we can readily leverage the compute power of our GPU. However, if $K$ is large and $M$ and $N$ are small, we will need to employ some tricks to parallelize the reduction itself. For the sake of this demo, we'll focus on the first two cases.

Before writing the kernel, we should implement a CPU version for correctness checks.
```
function mm_ref(A, B, C, M, N, K) {
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      let res = 0;
      for (let k = 0; k < K; ++k) {
        res += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = res;
    }
  }
}
```

And then a *working* kernel. Even totally unoptimized, it is still reasonably fast compared to the CPU code.

```
function generateMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
compute void main(device float[] A : register(u0),
                  device float[] B : register(u1),
                  device float[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  uint c_idx = n + m * ${N};
  float result = 0.0;
  for (uint k = 0; k < ${K}; k++) {
    uint a_idx = k + m * ${K};
    uint b_idx = n + k * ${N};
    result += A[a_idx] * B[b_idx];
  }
  C[c_idx] = result;
}`;
}
```

There's a fair amount to observe about the source above.

- We're parallelizing over the output dimensions `M` and `N`
  - `numthreads(8, 8, 1)` sets the size of each workgroup, which is why `mm` calls `passEncoder.dispatch(M / 8, N / 8, 1)`
  - We're dispatching `M * N / 64` groups of 64 threads each
  - `threadID` refers to the global thread index (rather than the index local to the group)
- Everything is typed
  - There's `float`, `uint`, and even vector types like `float3`
  - `threadID`s need to be converted from `float` into `uint`, oddly enough
- We are hard coding the sizes of `M`, `N`, and `K` into the kernel
  - This might allow the kernel compiler to do a better job
  - This requires us to recompile for every new size

### Putting it together

```
async function run() {
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();

  const M = 1024;
  const N = 1024;
  const K = 1024;
  const mm = createMatrixMultiplication(device, M, N, K);

  const A = randGPU(device, M * K);
  const B = randGPU(device, K * N);
  const C = randGPU(device, M * N);

  const t0 = performance.now();
  device.getQueue().submit([mm(A, B, C)]);
  // Awaiting the read-back also waits for the multiplication to finish.
  const result = await toCPU(device, C, M * N);
  const t1 = performance.now();

  // log the gflops achieved
  const flops = M * N * K * 2;
  console.log(flops / ((t1 - t0) * 1e6));
}
```

The goal of the above code is to benchmark our matrix multiplication implementation. `flops` is the total number of floating point operations: `K` multiplications and `K` additions for each of the `M * N` elements of the output. `performance.now()` reports time in milliseconds, so to get gigaflops (billions of floating point operations per second) we divide `flops` by the elapsed milliseconds times `1e6` (`1e9` operations per gigaflop divided by `1e3` milliseconds per second). Note that the measured time also includes copying `C` back to the CPU.

On my machine I see around 160Gflops. This is quite bad compared to the [M1 GPU specs](https://www.notebookcheck.net/Apple-M1-7-Core-GPU-GPU-Benchmarks-and-Specs.504540.0.html), which claim it can hit nearly 2.3Tflops. We've got a ways to go!

Before optimizing, let's go over the two helper functions: `randGPU`, which allocates and fills GPU memory, and `toCPU`, which copies it back.

### Memory Management

```
function randGPU(device, numel) {
  const [gpu, cpu] = device.createBufferMapped({
    size: numel * 4, // sizeof float
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });
  // Fill the mapped memory directly with small random values.
  const data = new Float32Array(cpu);
  for (let i = 0; i < numel; ++i) {
    data[i] = Math.random() / 500;
  }
  gpu.unmap();
  return gpu;
}
```

We create a mapped [Buffer](https://gpuweb.github.io/gpuweb/#buffer-interface) and populate it with small random values. Calling `gpu.unmap()` gives up JavaScript's access to the buffer's memory and allows us to use it on the GPU itself.

Unlike allocating data for the GPU, getting data from the GPU to the CPU is asynchronous.
```
async function toCPU(device, gpu_array, numel) {
  const buffer = device.createBuffer({
    size: numel * 4,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  const commandEncoder = device.createCommandEncoder();
  commandEncoder.copyBufferToBuffer(gpu_array, 0, buffer, 0, numel * 4);
  device.getQueue().submit([commandEncoder.finish()]);

  return new Float32Array(await buffer.mapReadAsync());
}
```

We first create a read buffer that can be mapped to the CPU and encode the command required to copy another GPU buffer (the output of `mm`) into it. After submitting the [buffer to buffer](https://gpuweb.github.io/gpuweb/#dom-gpucommandencoder-copybuffertobuffer) copy, we have to `await` the mapping operation and finally wrap the result in a new `Float32Array`.

### Optimizing the kernel

Before we can optimize, we should improve the benchmark a bit. Instead of a single invocation, we should leverage the pipelining ability of asynchronous dispatch.

```
const t0 = performance.now();
device.getQueue().submit([
  mm(A, B, C),
  mm(C, B, A),
  mm(A, C, B),
  mm(B, A, C),
  mm(A, B, C),
  mm(C, B, A),
  mm(A, C, B),
  mm(B, A, C),
]);
const result = await toCPU(device, C, M * N);
const t1 = performance.now();
const flops = M * N * K * 2 * 8;
console.log(flops / ((t1 - t0) * 1e6));
```

We now do 8 back to back matrix multiplications, so we have to update the `flops` calculation accordingly. Swapping the roles of `A`, `B`, and `C` like this only works because `M`, `N`, and `K` are all equal here.

#### Vectorized fused multiply-accumulate

The first step to optimization is to ensure the hottest instructions are also the most efficient. Right now our kernel uses a scalar multiplication followed by a scalar addition. We should instead use a fused variant, often referred to as an FMA (fused multiply-add).

The shading language used here exposes an approximation of an FMA instruction, `mad`. Since this single instruction does `a * b + c`, it might speed up our computation. (Empirically, I didn't measure any noticeable speedup from swapping to `mad` over `a * b + c`, possibly because the compiler was already helping us out.)
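Because `mad` may round differently from a separate multiply and add, and because `mm_ref` accumulates in JavaScript's 64-bit numbers while the kernel accumulates in 32-bit floats, a GPU result will generally not match the CPU reference bit for bit. A correctness check therefore needs a tolerance. Below is a minimal sketch; the name `allClose` and its default tolerances are illustrative choices, not part of the original code.

```javascript
// Hypothetical helper (not from the original code): compares a GPU result with
// a CPU reference, accepting |actual - expected| <= atol + rtol * |expected|.
// The default tolerances are illustrative and may need tuning for large K.
function allClose(actual, expected, { rtol = 1e-3, atol = 1e-5 } = {}) {
  if (actual.length !== expected.length) {
    throw new Error(`length mismatch: ${actual.length} vs ${expected.length}`);
  }
  let maxAbsDiff = 0;
  let firstBadIndex = -1;
  for (let i = 0; i < actual.length; ++i) {
    const diff = Math.abs(actual[i] - expected[i]);
    // Written as !(diff <= tol) so that a NaN difference counts as a failure.
    if (!(diff <= atol + rtol * Math.abs(expected[i])) && firstBadIndex === -1) {
      firstBadIndex = i;
    }
    if (diff > maxAbsDiff) {
      maxAbsDiff = diff;
    }
  }
  return { ok: firstBadIndex === -1, maxAbsDiff, firstBadIndex };
}
```

As a usage sketch, assuming CPU-side copies of the inputs are kept (hypothetical `A_cpu` and `B_cpu`; `randGPU` above only returns the GPU buffer): run a single `mm(A, B, C)`, compute `mm_ref(A_cpu, B_cpu, expected, M, N, K)`, and check `allClose(await toCPU(device, C, M * N), expected).ok`. The chained benchmark overwrites its own inputs, so it isn't suitable for this check.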
Another potential issue with performance is that our kernel uses scalars instead of vectors. There are many ways to vectorize our code, but for now we can vectorize the contiguous output dimension `N`. Since matrix `A` doesn't have dimension `N`, we can continue to read it as a scalar `float[]` array. For `B` and output `C`, we will want to swap to `float4[]` arrays.

```
function generateVectorizedMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
compute void main(device float[] A : register(u0),
                  device float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  uint c_idx = n + m * ${N/4};
  float4 result = float4(0.0, 0.0, 0.0, 0.0);
  for (uint k = 0; k < ${K}; k++) {
    uint a_idx = k + m * ${K};
    float a_elem = A[a_idx];
    float4 a = float4(a_elem, a_elem, a_elem, a_elem);
    uint b_idx = n + k * ${N/4};
    float4 b = B[b_idx];
    result = mad(a, b, result);
  }
  C[c_idx] = result;
}`;
}
```

We will also need to generate the source with the new function and update the dispatch logic to account for the fact that each thread now reads and writes 4 elements at a time. (This assumes `N` is a multiple of 32, so that `N / 8 / 4` is a whole number of workgroups.)

```
function createMatrixMultiplication(device, M, N, K) {
  // ...
  const source = generateVectorizedMatrixMultiplicationKernel(M, N, K);
  // ...
  function mm(A, B, C) {
    // ...
    // Each thread now covers 4 adjacent outputs along N.
    passEncoder.dispatch(M / 8, N / 8 / 4, 1);
    // ...
  }
  return mm;
}
```

We've improved the performance a lot with this small change, going from 160Gflops to over 420Gflops. But we can do better!

#### Tiling

Instead of computing a single `float4` of output per thread, we can use multiple vectors: each thread computes a 4x4 tile of `C` with four `float4` accumulators, so every `float4` loaded from `B` is reused four times, once per row of the tile.
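A rough back-of-the-envelope model of why more accumulators help (an illustration, not something measured in this post): per step of the `k` loop, a thread that computes a `rows x cols` block of `C` loads `rows` floats of `A` and `cols` floats of `B`, and performs `rows * cols` scalar multiply-adds. The helper name `floatsLoadedPerMad` is made up for this sketch, and the model ignores caches entirely.

```javascript
// Back-of-the-envelope model (illustrative; ignores caches and shared memory):
// per step of the k loop, a thread computing a tileRows x tileCols block of C
// loads tileRows floats of A and tileCols floats of B and performs
// tileRows * tileCols scalar multiply-adds.
function floatsLoadedPerMad(tileRows, tileCols) {
  const floatsLoaded = tileRows + tileCols;
  const multiplyAdds = tileRows * tileCols;
  return floatsLoaded / multiplyAdds;
}

const variants = [
  ["naive kernel (1x1)", 1, 1],
  ["vectorized kernel (1x4)", 1, 4],
  ["tiled kernel (4x4)", 4, 4],
  ["hypothetical larger tile (8x8)", 8, 8],
];
for (const [name, rows, cols] of variants) {
  console.log(`${name}: ${floatsLoadedPerMad(rows, cols)} floats loaded per multiply-add`);
}
```

This prints 2, 1.25, 0.5, and 0.25 respectively. Larger tiles keep improving the ratio, but they also need more accumulator registers per thread (16 floats for a 4x4 tile, 64 for 8x8), so the best tile size has to be found empirically.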
```
function generateTiledMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
compute void main(device float4[] A : register(u0),
                  device float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  uint c_idx = (n + m * ${N/4}) * 4;
  float4 result0 = float4(0.0, 0.0, 0.0, 0.0);
  float4 result1 = float4(0.0, 0.0, 0.0, 0.0);
  float4 result2 = float4(0.0, 0.0, 0.0, 0.0);
  float4 result3 = float4(0.0, 0.0, 0.0, 0.0);
  for (uint k = 0; k < ${K}; k++) {
    uint a_idx = k + m * ${K};
    float4 a = A[a_idx];
    float4 a_x = float4(a.x, a.x, a.x, a.x);
    float4 a_y = float4(a.y, a.y, a.y, a.y);
    float4 a_z = float4(a.z, a.z, a.z, a.z);
    float4 a_w = float4(a.w, a.w, a.w, a.w);
    uint b_idx = n + k * ${N/4};
    float4 b = B[b_idx];
    result0 = mad(a_x, b, result0);
    result1 = mad(a_y, b, result1);
    result2 = mad(a_z, b, result2);
    result3 = mad(a_w, b, result3);
  }
  C[c_idx + 0] = result0;
  C[c_idx + 1] = result1;
  C[c_idx + 2] = result2;
  C[c_idx + 3] = result3;
}`;
}
```

Since each thread now covers 4 rows and 4 columns of `C`, the dispatch shrinks again, to `passEncoder.dispatch(M / 8 / 4, N / 8 / 4, 1)`, with the source generated by `generateTiledMatrixMultiplicationKernel`.

That brings us to around 680Gflops!

One caveat: as written, this kernel doesn't use plain row-major layouts for `A` and `C`. Each `float4` it reads from `A` (at `a_idx = k + m * K`) is treated as rows `4m` to `4m + 3` of a single column `k`, so `A` has to be packed that way, and the four `float4`s written starting at `c_idx` are the four rows of one 4x4 tile, so `C` comes out tile by tile. This makes no difference for the benchmark, which only uses random data, but a real matrix multiplication would need to repack `A` and unpack `C` (or change the indexing).

#### More to go

There's more we can do, such as increasing the tile size and unrolling the reduction loop. I'll leave these optimizations as an exercise to the reader. These types of improvements might also be well handled by generating the source instead of hand-tuning it. In fact, there exist compilers, [like TVM](https://tvm.apache.org/2020/05/14/compiling-machine-learning-to-webassembly-and-webgpu), that can tune all of this for us and have been achieving performance much closer to native code.

It's great that the web is starting to adopt native GPU support, and it's also really pleasant to be able to quickly hit good performance.
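A practical detail about the tiled kernel: it reads `A` as `float4[]` at `a_idx = k + m * K` and broadcasts `a.x` through `a.w` across four output rows, so each `float4` of `A` must hold rows `4m` to `4m + 3` of column `k`, and the four `float4`s it writes at `c_idx` form one 4x4 tile of `C`. Below is a minimal sketch of host-side helpers for that layout, assuming row-major `Float32Array` inputs with `M` and `N` divisible by 4. The names `packA`, `unpackC`, and `emulateTiledKernel` are hypothetical; the emulation is only a CPU model of the kernel's indexing, used to check the helpers, not a substitute for running the kernel.

```javascript
// Hypothetical helpers for the data layout the tiled kernel assumes.
// Assumes row-major Float32Array inputs and M, N divisible by 4.

// Packs row-major A (M x K) so that float4 number (m * K + k) holds
// rows 4m..4m+3 of column k, matching `a_idx = k + m * K` in the kernel.
function packA(A, M, K) {
  const out = new Float32Array(M * K);
  for (let m = 0; m < M / 4; ++m) {
    for (let k = 0; k < K; ++k) {
      for (let i = 0; i < 4; ++i) {
        out[(m * K + k) * 4 + i] = A[(4 * m + i) * K + k];
      }
    }
  }
  return out;
}

// Converts C from the kernel's tile order (float4 number c_idx + i holds
// row 4m+i, columns 4n..4n+3) back to a row-major M x N matrix.
function unpackC(tiled, M, N) {
  const out = new Float32Array(M * N);
  for (let m = 0; m < M / 4; ++m) {
    for (let n = 0; n < N / 4; ++n) {
      const c_idx = (n + m * (N / 4)) * 4;
      for (let i = 0; i < 4; ++i) {
        for (let j = 0; j < 4; ++j) {
          out[(4 * m + i) * N + 4 * n + j] = tiled[(c_idx + i) * 4 + j];
        }
      }
    }
  }
  return out;
}

// CPU model of the tiled kernel's indexing: one iteration per thread (m, n),
// with float4 index x mapped to floats 4x..4x+3.
function emulateTiledKernel(Apacked, B, M, N, K) {
  const C = new Float32Array(M * N);
  for (let m = 0; m < M / 4; ++m) {
    for (let n = 0; n < N / 4; ++n) {
      const c_idx = (n + m * (N / 4)) * 4;
      const result = [0, 0, 0, 0].map(() => [0, 0, 0, 0]); // result0..result3
      for (let k = 0; k < K; ++k) {
        const a_idx = k + m * K;
        const b_idx = n + k * (N / 4);
        for (let i = 0; i < 4; ++i) {   // i plays the role of a.x .. a.w
          for (let j = 0; j < 4; ++j) { // j indexes the lanes of b
            result[i][j] += Apacked[a_idx * 4 + i] * B[b_idx * 4 + j];
          }
        }
      }
      for (let i = 0; i < 4; ++i) {
        for (let j = 0; j < 4; ++j) {
          C[(c_idx + i) * 4 + j] = result[i][j];
        }
      }
    }
  }
  return C;
}
```

In a real pipeline, `packA` would run on the CPU data before it is uploaded, and `unpackC` on the array returned by `toCPU`.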