This is a quick overview of how to write a matrix multiplication for Safari leveraging the WebGPU API. This will run on both Macs and iPhones provided WebGPU is enabled.
The benchmarks in this document are done on an M1 chip.
The full code can be found here and a demo here.
First, we'll need an async function to set up our GPU. This is because the API is largely asynchronous and many core methods return promises we'll have to await.
async function run() {
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();
  const M = 1024;
  const N = 1024;
  const K = 1024;
  const mm = createMatrixMultiplication(device, M, N, K);
}
window.addEventListener('load', run);
We set up the adapter and a device. Then we scaffold out the invocation of a function that will create an mm function.
WebGPU requires some setup that we will want to run only once. The boilerplate includes:
function createMatrixMultiplication(device, M, N, K) {
  // BindGroupLayout
  const visibility = GPUShaderStage.COMPUTE;
  const type = "storage-buffer";
  const bindGroupLayout = device.createBindGroupLayout({
    bindings: [
      { binding: 0, visibility: visibility, type: type },
      { binding: 1, visibility: visibility, type: type },
      { binding: 2, visibility: visibility, type: type },
    ]
  });

  // PipelineLayout
  const pipelineLayout = device.createPipelineLayout({
    bindGroupLayouts: [bindGroupLayout],
  });

  // ComputePipeline
  const source = generateMatrixMultiplicationKernel(M, N, K);
  const computePipeline = device.createComputePipeline({
    layout: pipelineLayout,
    computeStage: {
      module: device.createShaderModule({
        code: source,
      }),
      entryPoint: "main"
    }
  });

  // define the mm function
  function mm(A, B, C) {
    // see below
  }
  return mm;
}
With the "run once" boilerplate out of the way we've exposed two more functions to implement.
The first is a bit more boilerplate: the mm
function.
What we'll do in this function is the work that is required for every invocation. This includes
Instead of immediately invoking the generated command buffer, we will return it. This will free up our ability to chain together command buffers and potentially hide dispatch latency.
function mm(A, B, C) {
  const commandEncoder = device.createCommandEncoder();
  const bindGroup = device.createBindGroup({
    layout: bindGroupLayout,
    bindings: [
      { binding: 0, resource: { buffer: A, size: M * K * 4 } },
      { binding: 1, resource: { buffer: B, size: N * K * 4 } },
      { binding: 2, resource: { buffer: C, size: M * N * 4 } },
    ]
  });
  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(computePipeline);
  passEncoder.setBindGroup(0, bindGroup);
  passEncoder.dispatch(M / 8, N / 8, 1);
  passEncoder.endPass();
  return commandEncoder.finish();
}
Note that we multiply the number of elements by 4 because the buffer size is in bytes and we are using floats, which are 4 bytes each.
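For M = N = K = 1024, for example, each of these buffers works out to 1024 * 1024 * 4 bytes = 4 MiB.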
The second function we'll need to implement is more interesting: generateMatrixMultiplicationKernel. For the sake of simplicity we will hardcode the input sizes. These can always be passed as arguments to the kernel (that'd require more binding code).
If you're not familiar with matrix multiplication, the operation is defined as $(AB)_{mn} = \sum_k a_{mk} b_{kn}$, where $k$ is the reduction axis.
Computationally, if k is small, we end up with a memory-bound operation. If k is large and m or n is large as well, we can readily leverage the compute power of our GPU. However, if k is large while m and n are small, we will need to employ some tricks to parallelize the reduction itself. For the sake of this demo, we'll focus on the first two cases.
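To make that concrete, here is a rough back-of-the-envelope estimate (my own, not from the original code) for the 1024 x 1024 x 1024 case benchmarked below:

// Rough arithmetic-intensity estimate for M = N = K = 1024 (illustrative only).
const flops = 2 * 1024 * 1024 * 1024;  // one multiply and one add per (m, n, k)
const bytes = 3 * 1024 * 1024 * 4;     // read A and B, write C, 4 bytes per float
console.log(flops / bytes);            // ~170 flops per byte of traffic

At well over a hundred floating point operations per byte of memory traffic, this case should be firmly compute bound.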
Before writing the kernel, we should implement a CPU version for correctness checks.
function mm_ref(A, B, C, M, N, K) {
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      let res = 0;
      for (let k = 0; k < K; ++k) {
        res += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = res;
    }
  }
}
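As a quick sanity check of the row-major layout the kernels below assume, here is a tiny hypothetical invocation (the 2x2 values are my own, not from the original post):

const A = new Float32Array([1, 2, 3, 4]); // 2x2, row-major
const B = new Float32Array([5, 6, 7, 8]); // 2x2, row-major
const C = new Float32Array(4);
mm_ref(A, B, C, 2, 2, 2);
// C is now [19, 22, 43, 50]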
And then a working kernel. Even totally unoptimized, it is still reasonably fast compared to the CPU code.
function generateMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
compute void main(device float[] A : register(u0),
                  device float[] B : register(u1),
                  device float[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  uint c_idx = n + m * ${N};
  float result = 0.0;
  for (uint k = 0; k < ${K}; k++) {
    uint a_idx = k + m * ${K};
    uint b_idx = n + k * ${N};
    result += A[a_idx] * B[b_idx];
  }
  C[c_idx] = result;
}`
}
There's a fair amount to observe about the source above.

- The thread grid spans the output dimensions N and M: numthreads is directly related to passEncoder.dispatch
- threadID refers to the global thread index (rather than local to the group)
- float, uint, and even vector types like float3 are available
- threadIDs need to be converted from float into uint, oddly enough
- String interpolation bakes M, N, and K into the kernel

Now let's put everything together and benchmark it.

async function run() {
  const adapter = await navigator.gpu.requestAdapter();
  const device = await adapter.requestDevice();
  const M = 1024;
  const N = 1024;
  const K = 1024;
  const mm = createMatrixMultiplication(device, M, N, K);
  const A = randGPU(device, M * K);
  const B = randGPU(device, K * N);
  const C = randGPU(device, M * N);
  const t0 = performance.now();
  device.getQueue().submit([mm(A, B, C)]);
  const result = await toCPU(device, C, M * N);
  const t1 = performance.now();
  // log the gflops achieved
  const flops = M * N * K * 2;
  console.log(flops / ((t1 - t0) * 1e6));
}
The goal of the above code is to benchmark the performance of our matrix multiplication implementation. flops is the total number of floating point operations: K multiplications and K additions for each of the M * N elements in the output. performance.now() has millisecond resolution, so to calculate gigaflops (billions of floating point operations per second) we divide by the elapsed milliseconds multiplied by 1e6.
On my machine I see around 160Gflops. This is quite bad compared to the M1 GPU specs which claim it can hit nearly 2.3Tflops. We've got a ways to go!
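To put that in perspective, a single 1024 x 1024 x 1024 multiplication is 2 * 1024^3 ≈ 2.1 GFLOP, so 160Gflops works out to roughly 13ms per multiplication, while the advertised peak would finish it in about a millisecond.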
Before optimizing, let's go over the two functions randGPU and toCPU, which allocate GPU memory and copy it back to the CPU, respectively.
function randGPU(device, numel) {
  const [gpu, cpu] = device.createBufferMapped({
    size: numel * 4, // sizeof float
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });
  let rand = new Float32Array(numel);
  for (let i = 0; i < numel; ++i) {
    rand[i] = Math.random() / 500;
  }
  new Float32Array(cpu).set(rand);
  gpu.unmap();
  return gpu;
}
We create a mapped Buffer and populate it with small random values. Calling gpu.unmap() releases JavaScript's ability to write to the buffer and allows us to use it on the GPU itself.
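The same mapped-buffer pattern can be used to upload specific values rather than random ones. Here is a sketch of a hypothetical toGPU helper (not part of the original code) that mirrors randGPU:

function toGPU(device, array) {
  // Create a mapped buffer and copy the provided Float32Array into it.
  const [gpu, cpu] = device.createBufferMapped({
    size: array.length * 4, // sizeof float
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });
  new Float32Array(cpu).set(array);
  gpu.unmap(); // hand the buffer over to the GPU
  return gpu;
}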
Unlike allocating data for the GPU, getting data from the GPU to the CPU is asynchronous.
async function toCPU(device, gpu_array, numel) {
  const buffer = device.createBuffer({
    size: numel * 4,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  const commandEncoder = device.createCommandEncoder();
  commandEncoder.copyBufferToBuffer(gpu_array, 0, buffer, 0, numel * 4);
  device.getQueue().submit([commandEncoder.finish()]);
  return new Float32Array(await buffer.mapReadAsync());
}
We first create a read buffer that can be mapped to the CPU and encode the command required to copy another GPU buffer (the output of mm) into it. After submitting the buffer-to-buffer copy, we have to await the mapping operation and finally populate a new Float32Array.
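Putting toCPU together with mm_ref (and the hypothetical toGPU helper sketched above), a correctness check might look roughly like this; this is my own sketch and is not part of the original code:

async function checkMM(device, mm, M, N, K) {
  // Known inputs, generated on the CPU.
  const A_cpu = new Float32Array(M * K).map(() => Math.random() / 500);
  const B_cpu = new Float32Array(K * N).map(() => Math.random() / 500);
  const C_ref = new Float32Array(M * N);
  mm_ref(A_cpu, B_cpu, C_ref, M, N, K);
  // Run the GPU kernel on the same data.
  const A = toGPU(device, A_cpu);
  const B = toGPU(device, B_cpu);
  const C = toGPU(device, new Float32Array(M * N));
  device.getQueue().submit([mm(A, B, C)]);
  const C_gpu = await toCPU(device, C, M * N);
  // Compare with a loose tolerance to allow for floating point reordering.
  for (let i = 0; i < M * N; ++i) {
    if (Math.abs(C_gpu[i] - C_ref[i]) > 1e-4) {
      throw new Error(`mismatch at index ${i}`);
    }
  }
}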
Before we can optimize, we should improve the benchmark a bit. Instead of a single invocation, we should leverage the pipelining ability of asynchronous dispatch.
const t0 = performance.now();
device.getQueue().submit([
  mm(A, B, C),
  mm(C, B, A),
  mm(A, C, B),
  mm(B, A, C),
  mm(A, B, C),
  mm(C, B, A),
  mm(A, C, B),
  mm(B, A, C),
]);
const result = await toCPU(device, C, M * N);
const t1 = performance.now();
const flops = M * N * K * 2 * 8;
console.log(flops / ((t1 - t0) * 1e6));
We now do 8 back-to-back matrix multiplications, so we have to update the flops calculation accordingly.
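With M = N = K = 1024 that comes to 8 * 2 * 1024^3 ≈ 17.2 GFLOP in total.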
The first step to optimization is to ensure the hottest instructions are also the most efficient. Right now our kernel uses scalar multiplication and scalar addition. We should instead use a fused variant, often referred to as an FMA. WebGPU exposes an approximation of an FMA instruction, mad. Since this single instruction does a * b + c, it might speed up our computation. (Empirically, I didn't measure any noticeable speedup from swapping to mad over a * b + c, which may be the compiler helping us out.)
Another potential issue with performance is that our kernel uses scalars instead of vectors. There are many ways to vectorize our code, but for now we can vectorize the contiguous output dimension N. Since matrix A doesn't have dimension N, we can continue to read it as a scalar float[] array. For B and output C, we will want to swap to float4[] arrays.
function generateVectorizedMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
compute void main(device float[] A : register(u0),
                  device float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  uint c_idx = n + m * ${N/4};
  float4 result = float4(0.0, 0.0, 0.0, 0.0);
  for (uint k = 0; k < ${K}; k++) {
    uint a_idx = k + m * ${K};
    float a_elem = A[a_idx];
    float4 a = float4(a_elem, a_elem, a_elem, a_elem);
    uint b_idx = n + k * ${N/4};
    float4 b = B[b_idx];
    result = mad(a, b, result);
  }
  C[c_idx] = result;
}`
}
We will also need to update the dispatch logic to account for the fact that we are reading and writing 4 elements at a time.
function createMatrixMultiplication(device, M, N, K) {
  // ...
  function mm(A, B, C) {
    // ...
    passEncoder.dispatch(M / 8, N / 8 / 4, 1);
    // ...
  }
  return mm;
}
We've improved the performance a lot with this small change, going from 160Gflops to over 420Gflops.
But we can do better!
Instead of a single vector, we can use multiple vectors.
function generateTiledMatrixMultiplicationKernel(M, N, K) {
  return `[numthreads(8, 8, 1)]
compute void main(device float4[] A : register(u0),
                  device float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
  uint m = uint(threadID.x);
  uint n = uint(threadID.y);
  uint c_idx = (n + m * ${N/4}) * 4;
  float4 result0 = float4(0.0, 0.0, 0.0, 0.0);
  float4 result1 = float4(0.0, 0.0, 0.0, 0.0);
  float4 result2 = float4(0.0, 0.0, 0.0, 0.0);
  float4 result3 = float4(0.0, 0.0, 0.0, 0.0);
  for (uint k = 0; k < ${K}; k++) {
    uint a_idx = k + m * ${K};
    float4 a = A[a_idx];
    float4 a_x = float4(a.x, a.x, a.x, a.x);
    float4 a_y = float4(a.y, a.y, a.y, a.y);
    float4 a_z = float4(a.z, a.z, a.z, a.z);
    float4 a_w = float4(a.w, a.w, a.w, a.w);
    uint b_idx = n + k * ${N/4};
    float4 b = B[b_idx];
    result0 = mad(a_x, b, result0);
    result1 = mad(a_y, b, result1);
    result2 = mad(a_z, b, result2);
    result3 = mad(a_w, b, result3);
  }
  C[c_idx + 0] = result0;
  C[c_idx + 1] = result1;
  C[c_idx + 2] = result2;
  C[c_idx + 3] = result3;
}`
}
That brings us to around 680Gflops!
There's more we can do, such as increasing the tile size and unrolling the reduction loop. I'll leave those optimizations as an exercise for the reader. These types of improvements might also be well handled by generating the source instead of hand-tuning it. In fact, there exist compilers, like TVM, that can tune all of this for us and have been hitting performance much closer to native code.
It's great that the web is starting to adopt native GPU support and it's also really pleasant to be able to quickly hit good performance.