
How to Get 1.5 TFlops of FP32 Performance on a Single M1 CPU Core

by @bwasti (mastodon)


If you're in the market for training large modern neural networks, this post won't really be relevant, since 1.5 TFlops is roughly 100x slower than an A100 (156 TFlops).

So, how on earth is 1.5 TFlops interesting?

We are not in the realm of beefy accelerators or GPU tensor cores. We are talking about real-world linear algebra performance that lives one cycle away from CPU registers.

Weirdly, Apple has been hiding this from us: the AMX instructions are undocumented! In this post we'll go through some code to lift that curtain. All the code uses the aarch64.h header in corsix's awesome repo: https://github.com/corsix/amx

What is an AMX Co-Processor?

It's basically SIMD on steroids. An important distinction is that the AMX:CPU ratio is not 1:1; not every core has its own AMX co-processor.

Loads and stores move either 64 bytes (16 floats) or, with a flag set, 128 bytes at a time. The minimum is as wide as a full AVX512 register.

But where do these values get loaded into or stored from? Such sizes would eat through the entire NEON register file (32 registers of 16 bytes each) pretty quickly. Well, there's a separate register file just for AMX, and it's kinda weird.

The registers are segmented into groups: X, Y and Z. For the arithmetic instructions, the X and Y groups hold inputs and the Z group holds outputs.

X and Y are pretty big! Each has 8 registers of 64 bytes, a full KB between them. But Z takes the cake and then some, with 64 registers of 64 bytes (4 KB in total).

(Spoiler: a full 1024 bytes (1/4 of the Z registers) can be populated with a single AMX instruction.)
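The size arithmetic is easy to check. Here is a minimal sketch that encodes it as C11 compile-time checks, assuming the M1-era geometry described in corsix's notes (64-byte registers; 8 in X, 8 in Y, 64 in Z). The constant names are made up for illustration and are not from aarch64.h.

```c
#include <assert.h>
#include <stddef.h>

/* Illustrative constants (not from aarch64.h): every AMX register is
   64 bytes wide, i.e. 16 floats or one AVX-512 register. */
enum {
  AMX_REG_BYTES = 64,
  AMX_X_REGS = 8,
  AMX_Y_REGS = 8,
  AMX_Z_REGS = 64,
};

/* X and Y: 8 * 64 = 512 bytes each, 1 KB together. */
static_assert(AMX_X_REGS * AMX_REG_BYTES + AMX_Y_REGS * AMX_REG_BYTES == 1024,
              "X + Y hold 1 KB");

/* Z: 64 * 64 = 4 KB. */
static_assert(AMX_Z_REGS * AMX_REG_BYTES == 4096, "Z holds 4 KB");

/* One 16x16 fp32 outer product is 256 floats * 4 bytes = 1 KB,
   so four of them fill Z exactly. */
static_assert(16 * 16 * sizeof(float) * 4 ==
                  (size_t)AMX_Z_REGS * AMX_REG_BYTES,
              "one fp32 outer product fills a quarter of Z");
```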

So how do we get from X and Y to Z? Well, the number of ways is so large that it doesn't fit cleanly in the ISA encoding. So Apple decided to encode most of the instruction information in a general purpose register instead. This turns out to be quite cool to work with, because it allows runtime (on-the-fly) configuration of the code being executed on AMX.
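To make the "instruction lives in a register" idea concrete, here is a sketch of a hypothetical helper, fma32_operand (not part of aarch64.h), that packs an AMX_FMA32 operand with plain shifts and ORs. The bit positions are the ones used by the AMX_FMA32 calls in the 32x32 kernel later in this post, as described in corsix's notes; treat any field beyond these as unverified.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper (not in aarch64.h). Field layout as used in this post:
     bits  0..8  : byte offset into the Y register file
     bits 10..18 : byte offset into the X register file
     bits 20..25 : Z row selector (0..3 for fp32 outer products)
     bit  27     : ignore the old Z contents (overwrite instead of accumulate) */
static uint64_t fma32_operand(uint64_t z_row, uint64_t x_offset,
                              uint64_t y_offset, int skip_z_input) {
  /* each value must fit in its field or it would corrupt a neighbor */
  assert(z_row < 64 && x_offset < 512 && y_offset < 512);
  return (skip_z_input ? 1ull << 27 : 0) | (z_row << 20) |
         (x_offset << 10) | (y_offset << 0);
}
```

Because the operand is just a uint64_t, it can be computed at runtime: for example, fma32_operand(1, offset + 64, offset, 0) builds the same value as the second AMX_FMA32 call in mm32x32xK once reset_z is 0.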

The goal of this post is simply to use the co-processor as efficiently as possible. There are vector-vector instructions that output vectors of the same length as their inputs, but those don't come anywhere close to saturating the compute capabilities of this chip. Instead, we'll have to use an outer product to really get things going.

What is an outer product? Assuming you have two input vectors u and v:

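$$u = \begin{bmatrix} u_1 \\ u_2 \\ \vdots \\ u_m \end{bmatrix}, \qquad v = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_n \end{bmatrix}$$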

The outer product is the matrix containing the product of every possible pairwise combination of their elements. (This hints at why the Z register group is so much bigger than X and Y: m + n inputs produce m × n outputs.)

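$$u \otimes v = u v^\top = \begin{bmatrix} u_1 v_1 & \cdots & u_1 v_n \\ u_2 v_1 & \cdots & u_2 v_n \\ \vdots & \ddots & \vdots \\ u_m v_1 & \cdots & u_m v_n \end{bmatrix}$$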
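For a concrete feel (numbers chosen arbitrarily), take m = 2 and n = 3: five inputs produce six products. With m = n = 16, the same pattern turns 32 inputs into 256 outputs.

$$\begin{bmatrix} 1 \\ 2 \end{bmatrix} \otimes \begin{bmatrix} 3 \\ 4 \\ 5 \end{bmatrix} = \begin{bmatrix} 1 \cdot 3 & 1 \cdot 4 & 1 \cdot 5 \\ 2 \cdot 3 & 2 \cdot 4 & 2 \cdot 5 \end{bmatrix} = \begin{bmatrix} 3 & 4 & 5 \\ 6 & 8 & 10 \end{bmatrix}$$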

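On the AMX chip, this boils down to a single instruction, AMX_FMA32 for 32-bit floats, that takes 16 floats from X and 16 floats from Y and adds their 16x16 outer product into Z.

By default it accumulates into whatever Z already holds. Setting a flag in the operand (bit 27) ignores the previous Z contents instead, which is how the first step of a reduction starts from a clean slate.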
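A plain-C model can pin down these semantics before touching hardware. The sketch below, fma32_model (a hypothetical name), is not the AMX instruction: it writes into a compact 16x16 array and ignores rounding details, whereas the real result is spread across Z registers. The orientation (row from the Y element, column from the X element) follows corsix's notes, and the reset argument mirrors bit 27 as used by reset_z in the code below.

```c
#include <stddef.h>

/* Illustrative model of one fp32 outer-product FMA (not the hardware op).
   reset_z != 0 mimics bit 27: the old contents of z are ignored. */
static void fma32_model(float z[16][16], const float x[16], const float y[16],
                        int reset_z) {
  for (size_t j = 0; j < 16; ++j) {      /* row: Y element */
    for (size_t i = 0; i < 16; ++i) {    /* column: X element */
      float prev = reset_z ? 0.0f : z[j][i];
      z[j][i] = prev + x[i] * y[j];
    }
  }
}
```

Calling it once with reset_z set and then repeatedly with it cleared is exactly the loop structure the AMX code uses.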

With this, we have all we need to write a matrix multiplication: repeatedly load 16 floats from each input matrix and accumulate their outer products into a 16x16 output. The reduction dimension K doesn't even matter: it only changes how many outer products we accumulate, never the size of the output.

Let's simplify the problem and implicitly transpose the matrix multiplication. Both A and B (our inputs) will be stored as K x 16 arrays with K (our reduction dimension) as the outermost dimension, so step k reads 16 contiguous floats from each. This doesn't really matter much in practice, but it simplifies our code a lot.

Here's a reference that we can use to check our proposed solution:

#include <stdint.h>

// C[n * 16 + m] = sum over k of A[k * 16 + m] * B[k * 16 + n]
void reference_16x16xK(float *A, float *B, float *C, uint64_t K) {
  for (uint32_t m = 0; m < 16; ++m) {
    for (uint32_t n = 0; n < 16; ++n) {
      C[n * 16 + m] = 0;
      for (uint32_t k = 0; k < K; ++k) {
        C[n * 16 + m] += A[k * 16 + m] * B[k * 16 + n];
      }
    }
  }
}
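The same result can be written as a running sum of outer products, one per step k, which is the structure the AMX loop will use. This sketch (outer_16x16xK is a hypothetical name) keeps reference_16x16xK's layout, and for the same inputs it should match the reference; the memset stands in for the reset flag on the first step.

```c
#include <stdint.h>
#include <string.h>

/* Matrix multiply as accumulated outer products (illustrative, plain C).
   A and B are K x 16 row-major; C[n * 16 + m] as in reference_16x16xK. */
static void outer_16x16xK(const float *A, const float *B, float *C,
                          uint64_t K) {
  memset(C, 0, 16 * 16 * sizeof(float));  /* plays the role of reset_z */
  for (uint64_t k = 0; k < K; ++k) {
    const float *x = A + k * 16;  /* the 16 floats AMX_LDX would load */
    const float *y = B + k * 16;  /* the 16 floats AMX_LDY would load */
    for (uint32_t n = 0; n < 16; ++n)      /* one full outer product... */
      for (uint32_t m = 0; m < 16; ++m)
        C[n * 16 + m] += x[m] * y[n];      /* ...added into C */
  }
}
```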

And here is how we might do it in AMX:

// assumes AMX was enabled beforehand with AMX_SET() from aarch64.h
// bit 27: ignore the old Z contents; only set for k == 0
uint64_t reset_z = 1ull << 27;

for (uint32_t k = 0; k < K; ++k) {
  // 64 bytes = 16 floats
  AMX_LDX((uint64_t)A + k * 64);
  AMX_LDY((uint64_t)B + k * 64);

  // a single 16x16 outer product, accumulated into Z
  AMX_FMA32(reset_z);
  reset_z = 0;
}

Interestingly, we didn't seem to address any registers. Secretly, we did. Just as reset_z is encoded as a bit in the operand, register indices are also encoded in the arguments passed to AMX_*. The pointers to A and B only ever use up to 56 bits, so Apple engineers stashed information in the other 8. We just implicitly set them all to 0, so we're using register 0 for both X and Y in this case.
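Here is a sketch of that packing as a hypothetical helper, amx_ldst_operand (not in aarch64.h). It assumes the layout the code in this post relies on: the address in the low 56 bits, the register index shifted to bit 56, and bit 62 requesting a 128-byte register pair (that flag shows up in the 32x32 kernel below). The asserts catch a pointer or index that would spill into a neighboring field.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper (not in aarch64.h) packing a load/store operand:
     bits  0..55 : memory address
     bits 56..61 : register index (this post uses 0..7 for X/Y, 0..63 for Z)
     bit  62     : move 128 bytes (a register pair) instead of 64 */
static uint64_t amx_ldst_operand(const void *ptr, uint64_t reg, int pair) {
  const uint64_t addr = (uint64_t)(uintptr_t)ptr;
  assert(reg < 64 && (addr >> 56) == 0);  /* fields must not overlap */
  return (pair ? 1ull << 62 : 0) | (reg << 56) | addr;
}
```

With it, AMX_LDX((uint64_t)A + k * 64) could be written as AMX_LDX(amx_ldst_operand(A + k * 16, 0, 0)), since A + k * 16 floats is the same address as k * 64 bytes past A.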

The code to store the Z registers to memory is a bit more complicated. For 32-bit floats, a 16x16 outer product only fills every fourth Z register, so with the default Z row selector of 0 we need to grab registers 0, 4, 8, ..., 60:

// Z registers 0, 4, ..., 60 hold the 16 output rows; row i goes to C + i * 64 bytes
for (uint64_t i = 0; i < 16; ++i) {
  const uint64_t z_register = (i * 4ull) << 56;
  AMX_STZ(z_register | (uint64_t)C + i * 64);
}

Unfortunately, when you run the code above you'll find that it's super slow. A paltry couple hundred GFlops. Why? Two reasons.

The first slowdown is a pipeline hazard. Every AMX_FMA32 depends on the previous one because we accumulate into the same Z registers every time. We end up hammering 25% of the Z register file at full throttle while the rest sits idle, which prevents instruction-level parallelism.
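The same kind of hazard shows up in ordinary scalar code, where it is easy to see in isolation. This is a scalar analogue, not AMX code: a single accumulator forms one long dependency chain, while four independent accumulators let consecutive adds overlap in the pipeline, much like rotating among four groups of Z rows. The function names are illustrative, and whether a compiler keeps the chains exactly as written depends on its floating-point settings.

```c
#include <stddef.h>

/* One dependency chain: every add waits for the previous add. */
static float sum_one_chain(const float *v, size_t n) {
  float acc = 0.0f;
  for (size_t i = 0; i < n; ++i) acc += v[i];
  return acc;
}

/* Four independent chains: adds to acc[0..3] do not depend on each other,
   so the hardware can overlap them. */
static float sum_four_chains(const float *v, size_t n) {
  float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  size_t i = 0;
  for (; i + 4 <= n; i += 4)
    for (size_t j = 0; j < 4; ++j) acc[j] += v[i + j];
  for (; i < n; ++i) acc[0] += v[i];  /* leftover elements */
  return (acc[0] + acc[1]) + (acc[2] + acc[3]);
}
```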

The next issue is that we're loading from memory inefficiently. We can load 128 bytes at once, yet the code above only loads 64 bytes per instruction. We can also kick off loads to different registers instead of reusing the same ones each time, which enables a bit more instruction-level parallelism.

So what's the plan?

We're going to load 128 bytes into X and Y and then calculate a 32x32 block. This involves 4 independent 16x16 outer products per step, which should expose instruction-level parallelism as well as use the loaded data more efficiently (each 64-byte register feeds two outer products).
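Before reading the AMX version, it may help to see the blocking in plain C. This hypothetical reference_32x32xK (not from the original code) extends reference_16x16xK to a 32x32 tile and spells out the four quadrants per step k, in the same order as the four AMX_FMA32 calls in mm32x32xK: quadrant q pairs X half (q & 1) with Y half (q >> 1).

```c
#include <stdint.h>

/* Plain-C reference for the 32x32 tile. A and B are K x 32 row-major
   (128 bytes per k); C[n * 32 + m] = sum_k A[k * 32 + m] * B[k * 32 + n]. */
static void reference_32x32xK(const float *A, const float *B, float *C,
                              uint64_t K) {
  for (uint32_t i = 0; i < 32 * 32; ++i) C[i] = 0.0f;
  for (uint64_t k = 0; k < K; ++k) {
    const float *x = A + k * 32;  /* what a 128-byte AMX_LDX would load */
    const float *y = B + k * 32;  /* what a 128-byte AMX_LDY would load */
    for (uint32_t q = 0; q < 4; ++q) {  /* four independent 16x16 blocks */
      const uint32_t m0 = (q & 1) * 16;   /* X half: 0 or 16 */
      const uint32_t n0 = (q >> 1) * 16;  /* Y half: 0 or 16 */
      for (uint32_t n = 0; n < 16; ++n)
        for (uint32_t m = 0; m < 16; ++m)
          C[(n0 + n) * 32 + (m0 + m)] += x[m0 + m] * y[n0 + n];
    }
  }
}
```

Each 16-float half of x and y is read by two of the four quadrants, which is the reuse that makes the 128-byte loads pay off.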

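// Computes a 32x32 tile: C[n * 32 + m] = sum over k of A[k * 32 + m] * B[k * 32 + n].
// Assumes AMX has already been enabled with AMX_SET() from aarch64.h.
void mm32x32xK(float* A, float* B, float* C, uint64_t K) {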

  // flag to load/store 128 bytes
  const uint64_t load_store_2 = 1ull << 62;
  const uint64_t load_store_width = 128; // in bytes

  // only set for k == 0
  uint64_t reset_z = 1ull << 27;

  for (uint32_t k = 0; k < K; ++k) {
    uint64_t idx = k % 4;
    // load to X, Y (skipping every other index because we're loading 128 bytes)
    AMX_LDX(load_store_2 | (idx * 2) << 56 | (uint64_t)A + k * load_store_width);
    AMX_LDY(load_store_2 | (idx * 2) << 56 | (uint64_t)B + k * load_store_width);

    // offset into X and Y registers is byte-wise
    const uint64_t offset = idx * load_store_width;

    // now we do 4 independent outer products (avoiding pipeline hazards)
    // operand bits: 20..25 Z row, 10..18 X byte offset, 0..8 Y byte offset
    AMX_FMA32(reset_z | (0ull << 20) | ((offset +  0ull) << 10) | ((offset +  0ull) << 0));
    AMX_FMA32(reset_z | (1ull << 20) | ((offset + 64ull) << 10) | ((offset +  0ull) << 0));
    AMX_FMA32(reset_z | (2ull << 20) | ((offset +  0ull) << 10) | ((offset + 64ull) << 0));
    AMX_FMA32(reset_z | (3ull << 20) | ((offset + 64ull) << 10) | ((offset + 64ull) << 0));
    reset_z = 0;
  }

  for (uint64_t i = 0; i < 16; ++i) {
    // store interleaved: Z registers 4i and 4i+1 form output row i,
    // Z registers 4i+2 and 4i+3 form output row 16 + i
    AMX_STZ(load_store_2 | ((i * 4ull + 0) << 56) | (uint64_t)C + i * load_store_width);
    AMX_STZ(load_store_2 | ((i * 4ull + 2) << 56) | (uint64_t)C + (16 + i) * load_store_width);
  }
}
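The interleaved stores are the least obvious part, so here is a plain-C emulation of the Z register file for this kernel (emu_mm32x32xK and its helpers are hypothetical names, and this is not AMX code). It assumes, following corsix's notes, that an fp32 outer product with Z row selector r puts the row for Y element j into Z register 4 * j + r, and that a 128-byte STZ writes a register and its successor. Under those assumptions the two stores per i produce a row-major 32x32 C.

```c
#include <stdint.h>
#include <string.h>

typedef struct {
  float z[64][16];  /* 64 Z registers of 16 floats (64 bytes) each */
} ZFile;

/* fp32 outer product into every fourth Z register, starting at r. */
static void emu_fma32(ZFile *zf, const float *x, const float *y, uint32_t r,
                      int reset_z) {
  for (uint32_t j = 0; j < 16; ++j)
    for (uint32_t i = 0; i < 16; ++i) {
      float *dst = &zf->z[4 * j + r][i];
      *dst = (reset_z ? 0.0f : *dst) + x[i] * y[j];
    }
}

/* 128-byte store: register reg, then register reg + 1. */
static void emu_stz_pair(const ZFile *zf, uint32_t reg, float *dst) {
  memcpy(dst, zf->z[reg], 64);
  memcpy(dst + 16, zf->z[reg + 1], 64);
}

static void emu_mm32x32xK(const float *A, const float *B, float *C,
                          uint64_t K) {
  ZFile zf;
  memset(&zf, 0, sizeof zf);
  for (uint64_t k = 0; k < K; ++k) {
    const float *x = A + k * 32, *y = B + k * 32;
    const int reset = (k == 0);
    emu_fma32(&zf, x,      y,      0, reset);  /* X lo * Y lo */
    emu_fma32(&zf, x + 16, y,      1, reset);  /* X hi * Y lo */
    emu_fma32(&zf, x,      y + 16, 2, reset);  /* X lo * Y hi */
    emu_fma32(&zf, x + 16, y + 16, 3, reset);  /* X hi * Y hi */
  }
  for (uint32_t i = 0; i < 16; ++i) {
    emu_stz_pair(&zf, 4 * i + 0, C + i * 32);         /* output row i      */
    emu_stz_pair(&zf, 4 * i + 2, C + (16 + i) * 32);  /* output row 16 + i */
  }
}
```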

I put comments above, but there are some interesting details related to the flags used by these instructions. Corsix has done a great job of explaining them in the documentation in the repo linked above.

So how fast did we get it? Well, it kinda depends on K, but we hit 1.5 TFlops for larger values 😁
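For anyone reproducing the measurement, the FLOP count is simple to derive: each k step issues four AMX_FMA32s, each doing 16 × 16 multiply-adds, which comes to 2 × 32 × 32 FLOPs per step. A minimal sketch with hypothetical helper names; timing many back-to-back calls and dividing is likely more reliable than timing a single call.

```c
#include <stdint.h>

/* FLOPs for one mm32x32xK call: 4 FMA32s * 256 multiply-adds * 2 FLOPs
   per k step, i.e. 2 * 32 * 32 * K. */
static uint64_t mm32x32xK_flops(uint64_t K) {
  return 2ull * 32 * 32 * K;
}

/* Throughput in TFlops for one call that took `seconds` of wall-clock time. */
static double tflops(uint64_t K, double seconds) {
  return (double)mm32x32xK_flops(K) / seconds / 1e12;
}
```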

It should be no surprise that larger problems get better relative performance: there's more opportunity for the cache to warm up and for the CPU to interleave instructions, and the fixed cost of the 32 Z stores at the end is spread over more FMAs.

Overall, these problem sizes are microscopic in the context of the large modern neural networks chasing general AI. However, this type of performance opens the door for smaller neural networks to find their place in real-world computation. If a prediction can run on a battery-powered laptop in a couple dozen nanoseconds, there's likely a lot of value that can be added in places that might otherwise have used heuristics. What do you think?

Thanks for reading!


If you're interested in following my work (I post mostly about machine learning, performance and a love of {Java,Type}Script), I use Twitter and Mastodon!