# WebAssembly Techniques to Speed Up Matrix Multiplication by 120x

****

This post is going to use [wasmblr](https://github.com/bwasti/wasmblr)
to implement matrix multiplication
in pure WebAssembly.  Then we'll optimize it until it's comparable to
[TensorFlow.js](https://www.tensorflow.org/js).
The result is an implementation that can process ~45 billion elements per second on an M1 chip,
a **~120x speedup** over a Javascript implementation.

The best thing about WebAssembly is that we'll be able to run all
the code [in browser](https://bwasti.github.io/wasmblr/matmul/)!
Here's the full [code listing](https://github.com/bwasti/wasmblr/tree/main/matmul_example).

<center>
<img src="https://i.imgur.com/WAb4K0l.png" style="display:inline;width:480px;max-width:80%;"/>
<img src="https://i.imgur.com/cLhu50x.png" style="display:inline;width:480px;max-width:80%;"/>
</center>

## Matrix Multiplication

Below is a nice visualization of 3x3x3 matrix multiplication:
<center>
<img src="https://www.mscroggs.co.uk/img/full/multiply_matrices.gif" style="width:480px;max-width:80%;"/>
<br>
<a target="_blank" href="https://www.mscroggs.co.uk/blog/tags/matrix%20multiplication">(source)</a>
</center>

The rows of the first matrix are ["dotted"](https://en.wikipedia.org/wiki/Dot_product) with the columns of the second matrix
for every possible pairwise combination.

In the above example, each dot product is ~6 operations
(3 multiplies and 3 adds, counting the accumulation into the output)
and we're doing 9 pairwise dot products,
so the total number of operations is 54.

For this post we'll be focusing on square matrices of size 128, 256 and 512,
which result in 4.2M, 33.6M and 268M operations per matrix multiplication.
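Counting one multiply and one add per term, each output is a $K$-long dot product costing $2K$ operations, so an $M \times K$ by $K \times N$ product costs $2 \cdot M \cdot N \cdot K$. A small C++ check of the numbers above (the helper name `matmul_ops` is our own, and it assumes the square sizes listed in the Results section):

```cpp
#include <cstdint>

// Operations in C (M x N) += A (M x K) * B (K x N): each of the M * N
// outputs is a K-long dot product costing K multiplies and K adds.
constexpr std::int64_t matmul_ops(std::int64_t M, std::int64_t N, std::int64_t K) {
  return 2 * M * N * K;
}

static_assert(matmul_ops(3, 3, 3) == 54, "the 3x3 example above");
static_assert(matmul_ops(128, 128, 128) == 4194304, "~4.2M");
static_assert(matmul_ops(256, 256, 256) == 33554432, "~33.6M");
static_assert(matmul_ops(512, 512, 512) == 268435456, "~268M");
```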

## Baseline


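```javascript
// Naive baseline: c (M x N) += a (M x K) * b (K x N),
// with all three matrices stored row-major in flat arrays.
function matmul(a, b, c, M, N, K) {
  for (let m = 0; m < M; ++m) {
    for (let n = 0; n < N; ++n) {
      for (let k = 0; k < K; ++k) {
        c[m * N + n] += a[m * K + k] * b[k * N + n];
      }
    }
  }
}
```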

If we're using Float32Arrays, this can process around 380,000,000
elements per second on my M1 MacBook.
It's typical to measure performance in this way (rather than "runs per second"),
because it normalizes for the size of the matrices, so runs at different sizes can be compared directly.
Another convention is to use a standard unit like "GFlops" (billion floating point operations per second)
and drop all the zeros.

The above implementation achieves 0.38GFlops on my machine.
Let's make it 120 times faster.
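Since all the numbers in this post are in GFlops, here is the conversion as a C++ sketch (our own helper names; it assumes the $2 \cdot M \cdot N \cdot K$ operation count, while the post's actual timing is done in Javascript). At N=128, 0.38 GFlops works out to roughly 4.19M / 0.38G ≈ 11 ms per multiplication:

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>

// GFlops = (2 * M * N * K operations) / seconds / 1e9
double gflops(std::int64_t M, std::int64_t N, std::int64_t K, double seconds) {
  assert(seconds > 0.0);
  const double ops = 2.0 * static_cast<double>(M) * static_cast<double>(N) *
                     static_cast<double>(K);
  return ops / seconds / 1e9;
}

// Wall-clock seconds for one call of fn, measured with a monotonic clock.
template <typename Fn>
double time_seconds(Fn&& fn) {
  const auto start = std::chrono::steady_clock::now();
  fn();
  const std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
  return elapsed.count();
}
```

Usage would look like `gflops(M, N, K, time_seconds([&] { run(); }))`, where `run` is a hypothetical stand-in for whatever executes one multiplication.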

## Implementation

To do so, we'll implement matrix multiplication in WebAssembly.
I'm going to use [wasmblr](https://github.com/bwasti/wasmblr)
because I like C++, but any in-browser assembler will work.

Why use an assembler instead of emscripten?
One reason is so we can sweep many different optimization variants
rather than try to guess the best parameters.  It turns out that we'll
need to tune independently for Firefox and Chrome.
Pre-compiled solutions like emscripten aren't ideal for such
iteration, as shipping many pre-built variants blows up the code size.

The code for this section is [here](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L154-L259).
If you'd like to skip it and just see the optimized code,
that's [here](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L261-L410).

#### Memory [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L156-L160)

Before we write any computation code, we're going to preallocate memory
for inputs and outputs.  In a particularly advanced implementation,
this might involve allocating scratch space as well, but we're not
going to do that.

WebAssembly memory is allocated in pages of 64KiB.  Let's calculate the
number of pages we need (4 bytes per float, rounded up to a whole page):

```cpp
auto pages = (M * N + K * N + M * K) * 4 / (1 << 16) + 1;
memory(pages).export_("mem");
```


Instead of messy pointers as arguments, we're just going
to hardcode the offsets for the inputs and output:
```cpp
auto A_off = 0;
auto B_off = M * K * 4;
auto C_off = (M * K + K * N) * 4;
```


Now we can just export the memory and let the user write
their arrays directly to the heap:

```javascript
// Float32Array views into the exported memory at A_off, B_off and C_off
const mem = instance.exports.mem;
const a = new Float32Array(mem.buffer, 0, M * K);
const b = new Float32Array(mem.buffer, M * K * 4, K * N);
const c = new Float32Array(mem.buffer, (M * K + K * N) * 4, M * N);
```

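The page count and the offsets can be sanity-checked together. A C++ sketch (the `Layout` struct and `make_layout` are our own names) that packs $A$, $B$ and $C$ back to back, reuses the page formula from above, and asserts that $C$ ends inside the allocation:

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int64_t kPageBytes = 1 << 16;  // one WebAssembly page (64KiB)

struct Layout {
  std::int64_t A_off, B_off, C_off;  // byte offsets of A, B and C
  std::int64_t pages;                // pages to request for the memory
};

// Row-major float32 matrices packed back to back: A (M x K), B (K x N), C (M x N).
Layout make_layout(std::int64_t M, std::int64_t N, std::int64_t K) {
  Layout l;
  l.A_off = 0;
  l.B_off = M * K * 4;
  l.C_off = (M * K + K * N) * 4;
  l.pages = (M * N + K * N + M * K) * 4 / kPageBytes + 1;  // same formula as above
  assert(l.C_off + M * N * 4 <= l.pages * kPageBytes);     // C must fit in memory
  return l;
}
```

For 128x128 matrices this places $B$ at byte 65536 and $C$ at byte 131072 and requests 4 pages. The `+ 1` asks for one spare page when the byte count is already a whole number of pages, which is harmless.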

#### Loops [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L162-L179)

Matrix multiplication is naively $O(n^3)$.
We'll be sticking with that approach, but curious readers should
certainly check out the [Strassen algorithm](https://en.wikipedia.org/wiki/Strassen_algorithm),
which offers an asymptotic speedup
(at the cost of extra memory and reduced numerical stability, which may be acceptable).

In WebAssembly (which is stack based),
a loop might look like this:

```cpp
auto m = local(i32); // we're going to loop over m

i32.const_(0); // push 0 to the stack
local.set(m); // set m = 0

loop(void_); // start the loop!

// body goes here
// ...

local.get(m);  // stack: [m]
i32.const_(1); // stack: [1, m]
i32.add();     // stack: [m + 1]
local.tee(m);  // stack: [m + 1] + update variable "m"
i32.const_(M); // stack: [M, m + 1]
i32.lt_u();    // stack: [true/false] (check if m + 1 < M)
br_if(0);      // if true, jump back to the start of the loop

end(); // end the loop!
```


We need [three](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L15-L28)
of [those](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L76-L104).
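Because the `br_if` comes at the end, this loop is really a do-while: the body always runs at least once, even if $M$ were 0. That's fine here since $M$ is a known constant, but it's easy to miss. A C++ sketch mirroring the control flow (our own function, recording which values of `m` the body sees):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// C++ mirror of the WebAssembly loop above: the body runs first and the
// br_if check comes last, so it behaves like a do-while loop.
std::vector<std::uint32_t> visited_indices(std::uint32_t M) {
  std::vector<std::uint32_t> visited;
  std::uint32_t m = 0;       // i32.const_(0); local.set(m)
  do {                       // loop(void_)
    visited.push_back(m);    // body goes here
    m = m + 1;               // local.get(m); i32.const_(1); i32.add(); local.tee(m)
  } while (m < M);           // i32.const_(M); i32.lt_u(); br_if(0)
  return visited;            // end()
}
```

For `M = 4` the body sees 0, 1, 2, 3; for `M = 0` it still runs once.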

#### Body [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L181-L225)

In the body of the loop we want to load from $A$, $B$ and $C$:
rather than overwrite $C$, we accumulate into it.
This is typical of matrix multiplication implementations,
here with an $\alpha$ value of $1$:

$$C' = \alpha C + A \cdot B$$

Each load operation will look something like this:
```cpp
// load original value of C
local.get(m);  // stack: [m]
i32.const_(N); // stack: [N, m]
i32.mul();     // stack: [m * N]
local.get(n);  // stack: [n, m * N]
i32.add();     // stack: [m * N + n]
i32.const_(4); // (size of a floating point number)
i32.mul();     // stack: [(m * N + n) * 4]
f32.load(0, C_off); // stack: [C] (read from C_off + (m * N + n) * 4)
```

[(and do the same for $A$ and $B$)](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L30-L58)

Now we can invoke the actual operation.
Since $A$, $B$ and $C$ are on the stack in the right order,
we simply call mul and then add.

```cpp
// stack: [B, A, C]
f32.mul(); // stack: [B * A, C]
f32.add(); // stack: [B * A + C]
auto c = local(f32);
local.set(c); // stack: [] (result saved in local "c")
```

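Note that we have to save $C$ to a local variable in order
to later store it: `f32.store` expects the address underneath the value
on the stack, so the result has to wait in a local while we push the address
(WebAssembly's stack and locals are a bit messy this way).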

The store operation looks a lot like the load operation:

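```cpp
// store new value to C
local.get(m);
i32.const_(N);
i32.mul();
local.get(n);
i32.add();     // index: m * N + n
i32.const_(4);
i32.mul();     // byte address: (m * N + n) * 4
local.get(c);  // the value to store goes on top of the address
f32.store(0, C_off);
```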

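It can help to see the loop body's addressing outside of WebAssembly. Below is a plain C++ sketch (our own helper names) that treats a byte vector as linear memory, assumes the $A$, $B$, $C$ layout from the Memory section, and performs the same load, multiply-add and store on every iteration:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// A flat byte buffer stands in for WebAssembly linear memory.
float load_f32(const std::vector<std::uint8_t>& heap, std::uint32_t addr) {
  assert(addr + sizeof(float) <= heap.size());  // wasm would trap out of bounds
  float v;
  std::memcpy(&v, heap.data() + addr, sizeof v);  // like f32.load
  return v;
}

void store_f32(std::vector<std::uint8_t>& heap, std::uint32_t addr, float v) {
  assert(addr + sizeof(float) <= heap.size());
  std::memcpy(heap.data() + addr, &v, sizeof v);  // like f32.store
}

// The naive loop body with explicit byte addressing.
void matmul_heap(std::vector<std::uint8_t>& heap, std::uint32_t M, std::uint32_t N,
                 std::uint32_t K) {
  const std::uint32_t A_off = 0;
  const std::uint32_t B_off = M * K * 4;
  const std::uint32_t C_off = (M * K + K * N) * 4;
  for (std::uint32_t m = 0; m < M; ++m) {
    for (std::uint32_t n = 0; n < N; ++n) {
      for (std::uint32_t k = 0; k < K; ++k) {
        const float c = load_f32(heap, C_off + (m * N + n) * 4);
        const float a = load_f32(heap, A_off + (m * K + k) * 4);
        const float b = load_f32(heap, B_off + (k * N + n) * 4);
        store_f32(heap, C_off + (m * N + n) * 4, c + a * b);  // C' = C + A * B
      }
    }
  }
}
```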

The result of all this [hard work](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L4-L107)?

Firefox:
```
N=128 (wasmblr): 0.57 gflops
```

Chrome:
```
N=128 (wasmblr): 0.59 gflops
```

About 1.5x faster than the Javascript baseline out of the box! Great.

## Optimization [[code]](https://github.com/bwasti/wasmblr/blob/main/matmul_example/mm.cc#L271-L407)

We can do better, but we'll need to pull out a couple of
non-obvious techniques.

#### Vectorization

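The first thing we can do is vectorize the multiplication.
WebAssembly's 128-bit SIMD type (`v128`) holds four 32-bit floats,
so let's vectorize the $N$ dimension four elements at a time.
That means we are loading 1 element from $A$ (from dimension $M$, splatted to all four lanes)
and 4 elements from $B$ and $C$.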

In pseudo-Javascript that's

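```javascript
// assumes N is a multiple of 4
for (let m = 0; m < M; ++m) {
  for (let n = 0; n < N; n += 4) {
    for (let k = 0; k < K; ++k) {
      // splat converts a scalar to a vector
      const A_vec = splat4(A[m * K + k]);
      // load4/store4 read/write 4 consecutive floats
      const B_vec = load4(B, k * N + n);
      const C_vec = load4(C, m * N + n);
      const tmp_vec = vec4_mul(A_vec, B_vec);
      store4(C, m * N + n, vec4_add(C_vec, tmp_vec));
    }
  }
}
```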

Here are the corresponding wasmblr calls:

```cpp
// ...
v128.f32x4_mul();
v128.f32x4_add();
// ...
v128.store(0, C_off);
```

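To check the loop structure outside the browser, here is a plain C++ sketch in which a `std::array<float, 4>` stands in for a `v128` (the helper names mirror the pseudo-Javascript and are our own; a real build would use WebAssembly SIMD instructions rather than these scalar loops). Each output element still accumulates over $k$ in the same order as the naive loop, so the results match it exactly:

```cpp
#include <array>
#include <cassert>
#include <vector>

// A std::array of four floats stands in for a v128 with four f32 lanes.
using f32x4 = std::array<float, 4>;

f32x4 splat4(float x) { return {x, x, x, x}; }                    // broadcast a scalar
f32x4 load4(const float* p) { return {p[0], p[1], p[2], p[3]}; }  // 4 consecutive floats
void store4(float* p, const f32x4& v) {
  for (int i = 0; i < 4; ++i) p[i] = v[i];
}
f32x4 mul4(const f32x4& a, const f32x4& b) {
  return {a[0] * b[0], a[1] * b[1], a[2] * b[2], a[3] * b[3]};
}
f32x4 add4(const f32x4& a, const f32x4& b) {
  return {a[0] + b[0], a[1] + b[1], a[2] + b[2], a[3] + b[3]};
}

// C (M x N) += A (M x K) * B (K x N), four columns of C per step.
// Assumes N is a multiple of 4 (no tail handling).
void matmul_vec4(const std::vector<float>& A, const std::vector<float>& B,
                 std::vector<float>& C, int M, int N, int K) {
  assert(N % 4 == 0);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; n += 4) {
      for (int k = 0; k < K; ++k) {
        const f32x4 a = splat4(A[m * K + k]);  // 1 element of A, in all lanes
        const f32x4 b = load4(&B[k * N + n]);  // 4 elements of B
        const f32x4 c = load4(&C[m * N + n]);  // 4 elements of C
        store4(&C[m * N + n], add4(c, mul4(a, b)));
      }
    }
  }
}
```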

#### Unrolling

Unrolling is a simple idea.
For example, the code below

```javascript
for (let n = 0; n < 4; ++n) {
  blah(n);
}
```

would become

```javascript
blah(0);
blah(1);
blah(2);
blah(3);
```


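This increases the amount of straight-line code
executed between branches.
This is important in WebAssembly applications
because loop book-keeping (the code to deal with the loop iteration variable)
takes a fair number of instructions to execute:
the loop shown earlier spends seven instructions per iteration on incrementing, comparing and branching.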
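The generated code in this post is fully unrolled, but partial unrolling with a "tail" loop (which the closing section mentions as future work) shows the same trade-off. A plain C++ sketch, with our own function name, that unrolls the $k$ loop of the naive multiplication by 4:

```cpp
#include <cassert>
#include <vector>

// C (M x N) += A (M x K) * B (K x N) with the k loop unrolled by 4.
// The unrolled body performs four multiply-adds per loop-condition check;
// a short tail loop handles the K % 4 leftover iterations.
void matmul_unroll_k4(const std::vector<float>& A, const std::vector<float>& B,
                      std::vector<float>& C, int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = C[m * N + n];
      int k = 0;
      for (; k + 4 <= K; k += 4) {  // unrolled main loop
        acc += A[m * K + k + 0] * B[(k + 0) * N + n];
        acc += A[m * K + k + 1] * B[(k + 1) * N + n];
        acc += A[m * K + k + 2] * B[(k + 2) * N + n];
        acc += A[m * K + k + 3] * B[(k + 3) * N + n];
      }
      for (; k < K; ++k) {  // tail
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = acc;
    }
  }
}
```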

#### Local Variables

This technique aims to increase the arithmetic intensity of the
inner-most loop.
WebAssembly doesn't have registers, so we'll be using
local variables and crossing our fingers that the browser's JIT
maps them onto registers for us.

Arithmetic intensity ($I$) refers to the number of arithmetic operations
we can perform per load. We'll want to keep the CPU busy while we wait
on load instructions.  Memory access is slow compared to arithmetic,
but luckily the CPU can overlap it with independent computation.

Matrix multiplication involves doing *every* pairwise
dot product along the $K$ dimension.  That means
for every $m$ loads of $A$ and $n$ loads of $B$
we can do $m\cdot n$ multiplications.
We can think of those parameters in terms of
arithmetic intensity like this:

$$I \approx \frac{m \cdot n}{m + n}$$

If we get $m$ and $n$ high enough, loads stop being the bottleneck
and the CPU stays busy.
We're only limited by how many values we can reasonably keep in local variables!
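Plugging a few tile shapes into this formula (a worked example, counting one load per value; these are not measurements):

$$I(2, 4) = \frac{8}{6} \approx 1.33, \quad I(2, 8) = \frac{16}{10} = 1.6, \quad I(4, 4) = \frac{16}{8} = 2, \quad I(8, 8) = \frac{64}{16} = 4$$

For square tiles $m = n = t$ this simplifies to $I = \frac{t^2}{2t} = \frac{t}{2}$, so doubling both $m$ and $n$ doubles the intensity but quadruples the $m \cdot n$ accumulators that have to stay in local variables.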

(*Aside: if you're curious how "busy" the CPU can get
in the world of WebAssembly, you can benchmark
varying levels of unrolled independent multiplications.
Some numbers can be collected in your browser with
this [example](https://bwasti.github.io/wasmblr/flops/).*)

To concretize the ideas above, here's a pseudo-implementation
(where we'll assume every loop is actually completely unrolled).

```javascript
// load into localA, O(k_unroll * m_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let m = 0; m < m_unroll; ++m) {
    localA[m * k_unroll + k] = A[base_A + m * K + k];
  }
}

// load into localB, O(k_unroll * n_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let n = 0; n < n_unroll; ++n) {
    localB[n * k_unroll + k] = B[base_B + k * N + n];
  }
}

// compute C, O(m_unroll * k_unroll * n_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let m = 0; m < m_unroll; ++m) {
    for (let n = 0; n < n_unroll; ++n) {
      const tmp = localA[m * k_unroll + k] * localB[n * k_unroll + k];
      localC[m * n_unroll + n] += tmp; // localC is m_unroll x n_unroll
    }
  }
}
```


This can serve as the inner kernel of most implementations, assuming correctly
calculated base_A and base_B offsets into the global memory.
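The same structure as a plain C++ sketch (our own function; scalar floats rather than `v128` vectors, with k_unroll fixed at 1). Small fixed-size arrays stand in for the WebAssembly locals; because `MU` and `NU` are compile-time constants, a C++ compiler can fully unroll these loops and keep the arrays in registers, which is roughly what the generated code hopes the JIT will do:

```cpp
#include <cassert>
#include <vector>

// Register blocking: an MU x NU tile of C stays in localC for the whole k loop.
// Each k step does MU + NU loads and MU * NU multiply-adds.
template <int MU, int NU>
void matmul_blocked(const std::vector<float>& A, const std::vector<float>& B,
                    std::vector<float>& C, int M, int N, int K) {
  assert(M % MU == 0 && N % NU == 0);  // no tail handling in this sketch
  for (int m0 = 0; m0 < M; m0 += MU) {
    for (int n0 = 0; n0 < N; n0 += NU) {
      float localC[MU * NU];
      for (int m = 0; m < MU; ++m)
        for (int n = 0; n < NU; ++n) localC[m * NU + n] = C[(m0 + m) * N + n0 + n];

      for (int k = 0; k < K; ++k) {
        float localA[MU];
        float localB[NU];
        for (int m = 0; m < MU; ++m) localA[m] = A[(m0 + m) * K + k];  // MU loads
        for (int n = 0; n < NU; ++n) localB[n] = B[k * N + n0 + n];    // NU loads
        for (int m = 0; m < MU; ++m)
          for (int n = 0; n < NU; ++n)
            localC[m * NU + n] += localA[m] * localB[n];  // MU * NU multiply-adds
      }

      for (int m = 0; m < MU; ++m)
        for (int n = 0; n < NU; ++n) C[(m0 + m) * N + n0 + n] = localC[m * NU + n];
    }
  }
}
```

Calling `matmul_blocked<2, 8>(A, B, C, M, N, K)` keeps 16 accumulators live, with 10 loads per step of $k$.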

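In the real code, this pseudo-implementation involves creating many local variables,
one for each value of $A$ and $B$ held in the unrolled tile.

```cpp
for (auto j = 0; j < K_unroll; ++j) {
  for (auto i = 0; i < M_unroll; ++i) {
    // create a v128 local for the (i, j) value of A
  }
  for (auto i = 0; i < N_unroll; ++i) {
    // create a v128 local for the (j, i) vector of B
  }
}
```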


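The actual loading process involves `local.set`-ing all the v128s
we pulled from memory (`load_a` below is an illustrative name for the A locals created above):

```cpp
for (auto k_unroll = 0; k_unroll < K_unroll; ++k_unroll) {
  for (auto m_unroll = 0; m_unroll < M_unroll; ++m_unroll) {
    local.get(a_off);  // runtime base address of the current block of A
    // load one float of A and splat it across all four lanes
    v128.load32_splat(0, A_off + (m_unroll * K + k_unroll) * 4);
    local.set(load_a[k_unroll * M_unroll + m_unroll]);
  }
}
```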

Note that the loops above are C++ loops, not WebAssembly loops:
they run once, while we generate the module, and emit one load per iteration,
so the resulting WebAssembly is fully unrolled.
This type of "meta-programming" is particularly useful when writing optimized
code.

The WebAssembly ends up looking like this:

```
local.get var99 v32x4.load_splat offset=512 align=1 local.set var5
local.get var99 v32x4.load_splat offset=1024 align=1 local.set var7
local.get var99 v32x4.load_splat offset=1536 align=1 local.set var9
local.get var99 v32x4.load_splat offset=2048 align=1 local.set var11
```

The same sort of thing (without splatting) should be done for $B$ and $C$.
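The generation-time unrolling can be reproduced in miniature. Below is a plain C++ sketch (our own function; it produces text rather than calling wasmblr) whose loops run once and emit one load per $(k, m)$ pair using the $A$ offset formula above. With $K = 128$, consecutive $m$ values land 512 bytes apart, the same spacing as in the listing above:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Generation-time loops: each (k, m) pair emits one load instruction, so
// the emitted code contains no loop. Produces text instead of calling wasmblr.
std::vector<std::string> emit_a_loads(int M_unroll, int K_unroll, int K, int A_off) {
  std::vector<std::string> code;
  for (int k_unroll = 0; k_unroll < K_unroll; ++k_unroll) {
    for (int m_unroll = 0; m_unroll < M_unroll; ++m_unroll) {
      // static offset of A[m][k] relative to the current block's base address
      const int offset = A_off + (m_unroll * K + k_unroll) * 4;
      code.push_back("v128.load32_splat offset=" + std::to_string(offset));
    }
  }
  return code;
}
```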

#### Tuning

Finally, we're going to lazily find good parameters
for the number of local variables and amount of unrolling
by tuning everything.

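```javascript
let best = { gflops: 0 };
for (const m of [1, 2, 4, 8, 16, 32]) {
  for (const n of [1, 2, 4, 8, 16, 32]) {
    for (const k of [1, 2, 4, 8, 16, 32]) {
      // bench (not shown here) times the variant with these unroll factors
      const gflops = await bench(mod, M, N, K, m, n, k);
      if (gflops > best.gflops) {
        best = { gflops, m, n, k };
      }
    }
  }
}
```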


This sweep pays off because Firefox and Chrome end up tuning to
different configurations despite implementing the exact same specification.
We've discovered properties of their JIT implementations that
would have been hard to reason about by looking through the code.
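The same sweep as a C++ sketch, so the selection logic is explicit. Here `bench` is a hypothetical callback that builds and times one variant and returns its GFlops; the sketch also assumes no tail handling, so it skips unroll factors that don't evenly divide the matrix sizes:

```cpp
#include <cassert>
#include <functional>
#include <initializer_list>

struct Config {
  int m, n, k;    // unroll factors
  double gflops;  // measured throughput
};

// Exhaustive sweep over unroll factors, keeping the fastest configuration.
Config tune(int M, int N, int K, const std::function<double(int, int, int)>& bench) {
  Config best{0, 0, 0, 0.0};
  for (int m : {1, 2, 4, 8, 16, 32}) {
    for (int n : {1, 2, 4, 8, 16, 32}) {
      for (int k : {1, 2, 4, 8, 16, 32}) {
        if (M % m != 0 || N % n != 0 || K % k != 0) continue;  // no tail code
        const double g = bench(m, n, k);
        if (g > best.gflops) best = {m, n, k, g};
      }
    }
  }
  return best;
}
```

Because the winner is chosen per browser at run time, nothing in the generator has to encode which JIT prefers which shape.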

## Results

The result of these optimizations is ~150 lines of C++ code
and ~50 lines of tuning Javascript.

In order to get a sense of how good a job we've done,
we can compare our performance with [this benchmark of TensorFlow.js](https://codepen.io/bwasti/pen/GRMebrx?editors=0012),
a heavily optimized neural network library.
This comparison isn't apples-to-apples because TF.js
doesn't have pre-allocated outputs, but it gives us a good sense
of how well we've done.

Firefox:

```bash
N=128 (tfjs-wasm): 9.99 gflops
N=256 (tfjs-wasm): 29.43 gflops
N=512 (tfjs-wasm): 31.47 gflops
N=128 (wasmblr): 43.95 gflops (unroll m: 2, n: 4, k: 16)
N=256 (wasmblr): 43.47 gflops (unroll m: 2, n: 4, k: 16)
N=512 (wasmblr): 43.47 gflops (unroll m: 2, n: 4, k: 8)
```


Chrome:
```bash
N=128 (tfjs-wasm): 29.54 gflops
N=256 (tfjs-wasm): 40.38 gflops
N=512 (tfjs-wasm): 44.03 gflops
N=128 (wasmblr): 46.14 gflops (unroll m: 2, n: 8, k: 1)
N=256 (wasmblr): 45.56 gflops (unroll m: 2, n: 8, k: 1)
N=512 (wasmblr): 45.98 gflops (unroll m: 2, n: 8, k: 1)
```


<center>
<img src="https://i.imgur.com/WAb4K0l.png" style="display:inline;width:480px;max-width:80%;"/>
<img src="https://i.imgur.com/cLhu50x.png" style="display:inline;width:480px;max-width:80%;"/>
</center>

Nice!

There are other optimizations worth exploring, such as
tiling chunks of the input matrices directly into scratch space
(rather than local variables)
or tuning loop orders.

We might also want to explore different unrolling parameters
and resultant "tail" code when the unrolling doesn't evenly
divide the input size.  This might let us utilize
the optimal number of local variables for each JIT.

I've left these ideas as an exercise to the reader. :^}