
WebAssembly Techniques to Speed Up Matrix Multiplication by 120x

by @bwasti


This post is going to use wasmblr to implement matrix multiplication in pure WebAssembly. Then we'll optimize it until it's comparable to TensorFlow.js. The result is a ~120x speedup over a Javascript implementation, processing ~45 billion elements per second on an M1 chip.

The best thing about WebAssembly is that we'll be able to run all the code in the browser! Here's the full code listing.

Matrix Multiplication

Below is a nice visualization of 3x3x3 matrix multiplication:


[animated visualization of a 3x3 matrix multiplication] (source)

The rows of the first matrix are "dotted" with the columns of the second matrix for every possible pairwise combination.

In the above example, each dot product is ~6 operations (3 multiplies, 3 adds) and we're doing 9 pairwise dot products, so the total number of operations is 54.

For this post we'll be focusing on matrices that result in 4.2M, 33.6M and 268M operations per matrix multiplication (square matrices with N = 128, 256 and 512, at 2·N³ operations each).

Baseline

We'll start with this implementation:

for (let m = 0; m < M; ++m) {
  for (let n = 0; n < N; ++n) {
    for (let k = 0; k < K; ++k) {
      c[m * N + n] += a[m * K + k] * b[k * N + n];
    }
  }
}

If we're using Float32Arrays, this can process around 380,000,000 elements per second on my M1 MacBook. It's typical to measure performance in this way (rather than "runs per second"), because it's invariant to the size of the matrices involved. Another convention is to use a standard unit like "GFlops" (billion floating point operations per second) and drop all the zeros.
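To make that concrete, here's a minimal sketch of how such a measurement might be taken (the helper name and iteration count are illustrative, not from the original benchmark):

function benchNaive(N, iters = 10) {
  // square M = N = K matrices, filled with dummy data
  const a = new Float32Array(N * N).fill(1);
  const b = new Float32Array(N * N).fill(1);
  const c = new Float32Array(N * N);
  const t0 = performance.now();
  for (let iter = 0; iter < iters; ++iter) {
    c.fill(0);
    for (let m = 0; m < N; ++m) {
      for (let n = 0; n < N; ++n) {
        for (let k = 0; k < N; ++k) {
          c[m * N + n] += a[m * N + k] * b[k * N + n];
        }
      }
    }
  }
  const seconds = (performance.now() - t0) / 1e3;
  // each run is 2 * N^3 floating point operations (a multiply and an add)
  return (2 * N ** 3 * iters) / seconds / 1e9; // GFlops
}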

The above implementation achieves 0.38GFlops on my machine. Let's make it 120 times faster.

Implementation

To do so, we'll implement matrix multiplication in WebAssembly. I'm going to use wasmblr because I like C++, but any in-browser assembler will work.

Why use an assembler instead of emscripten? One reason is so we can sweep many different optimization variants rather than trying to guess the best parameters. It turns out that we'll need to tune independently for Firefox and Chrome. Pre-compiled solutions like emscripten aren't ideal for this kind of iteration, as they end up blowing up the code size.

The code for this section is here. If you'd like to skip it and just see the optimized code, that's here.

Memory [code]

Before we write any computation code, we're going to preallocate memory for inputs and outputs. In a particularly advanced implementation, this might involve allocating scratch space as well, but we're not going to do that.

WebAssembly deals with pages of size 64KiB. Let's calculate the number of pages we need:

auto pages = (M * N + K * N + M * K) * 4 / (1 << 16) + 1;
memory(pages).export_("mem");
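For example, with M = N = K = 128 the three buffers take 3 · 128² · 4 = 196,608 bytes, so the formula above gives 196608 / 65536 + 1 = 4 pages.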

Instead of messy pointers as arguments, we're just going to hardcode the offsets for the inputs and output:

auto A_off = 0;
auto B_off = M * K * 4;
auto C_off = (M * K + K * N) * 4;

Now we can just export the memory and let the user write their arrays directly to the heap:

const mem = instance.exports.mem;
const a = new Float32Array(mem.buffer, 0, M * K);
const b = new Float32Array(mem.buffer, M * K * 4, K * N);
const c = new Float32Array(mem.buffer, (M * K + K * N) * 4, M * N);
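As an aside, the instance above can be created with the standard WebAssembly JS API (a quick sketch; wasmBytes is assumed to hold the bytes emitted by wasmblr):

const module = await WebAssembly.compile(wasmBytes);
const instance = await WebAssembly.instantiate(module);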

Loops [code]

Matrix multiplication is naively O(n³). We'll be sticking with that approach, but curious readers should certainly check out the Strassen algorithm, which offers an algorithmic speedup (at a potentially acceptable cost in memory overhead and numerical instability).

In WebAssembly (which is stack based), a loop might look like this:

auto m = local(i32); // we're going to loop over m

i32.const_(0); // push 0 to the stack
local.set(m); // set m = 0

loop(void_); // start the loop!

  // body goes here
  // ...
  
  local.get(m);  // stack: [m]
  i32.const_(1); // stack: [1, m]
  i32.add();     // stack: [m + 1]
  local.tee(m);  // stack: [m + 1] + update variable "m"
  i32.const_(M); // stack: [M, m + 1]
  i32.lt_u();    // stack: [true/false] (check if m + 1 < M)
  br_if(0);      // if true, jump back to the start of the loop

end(); // end the loop!
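In Javascript terms, this compiles down to roughly a do-while:

let m = 0;
do {
  // body
} while (++m < M);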

We need three of these, one each for m, n and k.

Body [code]

In the body of the loop we want to load from A, B and C. This is typical of matrix multiplication implementations with an α value of 1.

C = αC + AB

Each load operation will look something like this:

// load original value of C
local.get(m);  // stack: [m]
i32.const_(N); // stack: [N, m]
i32.mul();     // stack: [m * N]
local.get(n);  // stack: [n, m * N]
i32.add();     // stack: [m * N + n]
i32.const_(4); // (size of a floating point number)
i32.mul();     // stack: [(m * N + n) * 4]
f32.load(0, C_off);

(and do the same for A and B)

Now we can invoke the actual operation. Since A, B and C are on the stack in the right order, we simply call mul and then add.

// stack: [B, A, C]
f32.mul(); // stack: [B * A, C]
f32.add(); // stack: [B * A + C]
auto c = local(f32);
local.set(c);

Note that we have to save C to a local variable in order to later store it (WebAssembly's stack and locals are a bit messy this way).

The store operation looks a lot like the load operation:

// store new value to C      
local.get(m);
i32.const_(N);
i32.mul();
local.get(n);
i32.add();
i32.const_(4);
i32.mul();
local.get(c);
f32.store(0, C_off);

The result of all this hard work?

Firefox

N=128 (wasmblr): 0.57 gflops

Chrome

N=128 (wasmblr): 0.59 gflops

Nearly 2x faster out of the box! Great.

Optimization [code]

We can do better, but we'll need to pull out a couple of non-obvious techniques.

Vectorization

The first thing we can do is start vectorizing the multiplication. Let's vectorize along the N dimension. That means we load 1 element from A (and broadcast it across a vector) and 4 contiguous elements each from B and C.

In pseudo-Javascript that's

for (let m = 0; m < M; ++m) {
  for (let n = 0; n < N; n += 4) {
    for (let k = 0; k < K; ++k) {
      // splat broadcasts a scalar across a 4-wide vector
      A_vec = splat4(A[m * K + k]);
      // 4-wide loads of contiguous elements from B and C
      B_vec = B[k * N + n];
      C_vec = C[m * N + n];
      tmp_vec = vec4_mul(A_vec, B_vec);
      C_vec = vec4_add(tmp_vec, C_vec);
      // 4-wide store back into C
      C[m * N + n] = C_vec;
    }
  }
}

Here are the corresponding wasmblr calls:

v128.load32_splat(0, A_off);
v128.load(0, B_off);
v128.load(0, C_off);
// ...
v128.f32x4_mul();
v128.f32x4_add();
// ...
v128.store(0, C_off);

Unrolling

Unrolling is a simple idea. For example, the code below

for (let n = 0; n < 4; ++n) {
  blah(n);
}

would become

blah(0);
blah(1);
blah(2);
blah(3);

This increases the amount of straight-line code that gets executed. That's important in WebAssembly because loop book-keeping (the code that deals with the loop iteration variable) takes a fair number of instructions to execute.

Local Variables

This technique aims to increase the arithmetic intensity of the inner-most loop. Typically, this is done by loading values from memory into registers. WebAssembly doesn't have registers, so we'll be using local variables and crossing our fingers that the browser's JIT figures things out for us.

Arithmetic intensity (I) refers to the number of arithmetic operations we can perform per load. We'll want to keep the CPU busy while we wait on load instructions. Memory access is really slow, but luckily it can happen in the background.

Matrix multiplication involves doing every pairwise dot product along the K dimension. That means for every m loads of A and n loads of B we can do mn multiplications. We can think of those parameters in terms of arithmetic intensity like this:

I ∝ (m · n) / (m + n)

If we get m and n high enough, we'll never need to worry about keeping the CPU busy (for example, m = n = 8 means 64 multiplies for every 16 loads). We're only limited by how many values we can reasonably keep in local variables!

(Aside: if you're curious how "busy" the CPU can get in the world of WebAssembly, you can benchmark varying levels of unrolled independent multiplications. Some numbers can be collected in your browser with this example.)

To concretize the ideas above, here's a pseudo-implementation (where we'll assume every loop is actually completely unrolled).

// load into localA, O(k_unroll * m_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let m = 0; m < m_unroll; ++m) {
    localA[m * k_unroll + k] = A[base_A + m * K + k];
  }
}

// load into localB, O(k_unroll * n_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let n = 0; n < n_unroll; ++n) {
    localB[n * k_unroll + k] = B[base_B + k * N + n];
  }
}

// compute C, O(m_unroll * k_unroll * n_unroll)
for (let k = 0; k < k_unroll; ++k) {
  for (let m = 0; m < m_unroll; ++m) {
    for (let n = 0; n < n_unroll; ++n) {
      const tmp = localA[m * k_unroll + k] * localB[n * k_unroll + k];
      localC[m * n_unroll + n] += tmp;
    }
  }
}

This will work as a sub-program for most implementations, assuming correctly calculated base_A and base_B offsets into the global memory (and that the accumulated localC tile is eventually written back to C, as sketched below).
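A rough sketch of that write-back step (assuming localC starts at zero for this block and a base_C offset computed like base_A and base_B):

// write the accumulated tile back to C, O(m_unroll * n_unroll)
for (let m = 0; m < m_unroll; ++m) {
  for (let n = 0; n < n_unroll; ++n) {
    C[base_C + m * N + n] += localC[m * n_unroll + n];
  }
}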

In the real code, this will involve creating many local variables.

std::vector<int> load_a;
std::vector<int> load_b;
for (auto j = 0; j < K_unroll; ++j) {
  for (auto i = 0; i < M_unroll; ++i) {
    load_a.emplace_back(local(v128));
  }
  for (auto i = 0; i < N_unroll; ++i) {
    load_b.emplace_back(local(v128));
  }
}

The actual loading process involves local.set-ing all the v128s we pull from memory:

for (auto k_unroll = 0; k_unroll < K_unroll; ++k_unroll) {
  for (auto m_unroll = 0; m_unroll < M_unroll; ++m_unroll) {
    local.get(a_off);
    v128.load32_splat(0, A_off + (m_unroll * K + k_unroll) * 4);
    local.set(load_a.at(m_unroll * K_unroll + k_unroll));
  }
}

Note that the above code is unrolling things. We are looping through C++ constructs and emitting WebAssembly. This type of "meta-programming" is particularly useful when writing optimized code.

The WebAssembly ends up looking like this:

local.get $var99
v32x4.load_splat offset=512 align=1
local.set $var5
local.get $var99
v32x4.load_splat offset=1024 align=1
local.set $var7
local.get $var99
v32x4.load_splat offset=1536 align=1
local.set $var9
local.get $var99
v32x4.load_splat offset=2048 align=1
local.set $var11

The same sort of thing (without splatting) should be done for B and C.

Tuning

Finally, we're going to lazily find good parameters for the number of local variables and amount of unrolling by tuning everything.

for (let m of [1, 2, 4, 8, 16, 32]) {
  for (let n of [1, 2, 4, 8, 16, 32]) {
    for (let k of [1, 2, 4, 8, 16, 32]) {
      let gflops = await bench(mod, M, N, K, m, n, k);
    }
  }
}
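The bench helper itself isn't shown here, but it could look roughly like this (a hedged sketch: the generateWasm call that turns the unroll parameters into wasm bytes and the exported mm entry point are stand-ins for the real code in the linked listing):

async function bench(generateWasm, M, N, K, m, n, k, iters = 100) {
  const bytes = generateWasm(M, N, K, m, n, k); // emit wasm for this config
  const { instance } = await WebAssembly.instantiate(bytes);
  const mem = instance.exports.mem;
  // fill A and B with dummy data
  new Float32Array(mem.buffer, 0, M * K).fill(1);
  new Float32Array(mem.buffer, M * K * 4, K * N).fill(1);
  instance.exports.mm(); // warm-up run
  const t0 = performance.now();
  for (let i = 0; i < iters; ++i) {
    instance.exports.mm();
  }
  const seconds = (performance.now() - t0) / 1e3;
  return (2 * M * N * K * iters) / seconds / 1e9; // GFlops
}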

This approach works pretty well because Firefox and Chrome end up tuning to different configurations despite implementing the exact same specification. We've discovered properties of their JIT implementations that would have been hard to reason about by looking through the code.

Results

The result of these optimizations is ~150 lines of C++ code and ~50 lines of tuning Javascript.

In order to get a sense of how good a job we've done, we can compare our performance with this benchmark of TensorFlow.js, a heavily optimized neural network library. This comparison isn't apples-to-apples because TF.js doesn't have pre-allocated outputs, but it gives us a good sense of how well we've done.

Firefox:

N=128 (tfjs-wasm): 9.99 gflops
N=256 (tfjs-wasm): 29.43 gflops
N=512 (tfjs-wasm): 31.47 gflops
N=128 (wasmblr): 43.95 gflops (unroll m: 2, n: 4, k: 16)
N=256 (wasmblr): 43.47 gflops (unroll m: 2, n: 4, k: 16)
N=512 (wasmblr): 43.47 gflops (unroll m: 2, n: 4, k: 8)

Chrome:

N=128 (tfjs-wasm): 29.54 gflops
N=256 (tfjs-wasm): 40.38 gflops
N=512 (tfjs-wasm): 44.03 gflops
N=128 (wasmblr): 46.14 gflops (unroll m: 2, n: 8, k: 1)
N=256 (wasmblr): 45.56 gflops (unroll m: 2, n: 8, k: 1)
N=512 (wasmblr): 45.98 gflops (unroll m: 2, n: 8, k: 1)

Nice!

There are other optimizations worth exploring, such as tiling chunks of the input matrices directly into scratch space (rather than local variables) or tuning loop orders.

We might also want to explore different unrolling parameters and resultant "tail" code when the unrolling doesn't evenly divide the input size. This might let us utilize the optimal number of local variables for each JIT.

I've left these ideas as an exercise to the reader. :^}

Thanks for reading!

Some discussion can be found here: https://news.ycombinator.com/item?id=30073186