# How to Get 1.5 TFlops of FP32 Performance on a Single M1 CPU Core


If you're in the market for training large modern neural networks,
this post won't really be relevant,
since 1.5 TFlops is roughly 100x slower than an A100 (156 TFlops).

So, how on earth is 1.5 TFlops interesting?

- this is running on a single core of a battery-powered 2020 MacBook Air
- this is running with a ~0.5 *nanosecond* latency per instruction

We are not in the realm of beefy accelerators or GPU tensor cores.
We are talking about real-world linear algebra performance that
lives **one cycle** away from CPU registers.

Weirdly, Apple has been hiding this from us!
In this post we'll go through some code to lift that curtain.
All the code uses the `aarch64.h` header in corsix's
awesome repo: https://github.com/corsix/amx

## What is an AMX Co-Processor?

It's basically SIMD on steroids.  An important distinction is
that the AMX:CPU ratio is not 1:1; not every core has its own AMX co-processor.

Here are the sizes one might use to load or store values:
![](https://i.imgur.com/3gimUQ7.png)

The *minimum* is as wide as a full AVX512 register.

But, where do these values get loaded or stored from?
Clearly, values this size would eat through the entire NEON register file pretty quickly.
Well, there's a separate register file just for AMX and it's kinda weird.

The registers are segmented into groups: X, Y and Z.
For every instruction, the X and Y groups hold inputs and the Z group
holds outputs.

![](https://i.imgur.com/PUTfqIY.png)

As we can see, X and Y are pretty big! A full KB between them.
But Z takes the cake and then some:
![](https://i.imgur.com/Xqtu1xG.png)

(*Spoiler: a full 1024 bytes (1/4 of the Z registers) can be populated
with a single AMX instruction.*)
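Here is the arithmetic behind those sizes. It assumes the register counts documented in corsix's repo: 8 X registers, 8 Y registers and 64 Z registers, each 64 bytes wide.

$$\begin{aligned} \text{X} + \text{Y} &= (8 + 8) \times 64\ \text{bytes} = 1024\ \text{bytes} \\\\ \text{Z} &= 64 \times 64\ \text{bytes} = 4096\ \text{bytes} \\\\ 16 \times 16 \times 4\ \text{bytes} &= 1024\ \text{bytes} = \tfrac{1}{4}\,\text{Z} \end{aligned}$$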

So how do we get from X and Y to Z?
Well, the number of ways is so large that it doesn't fit cleanly
in the ISA encoding. So, Apple decided to encode most of the
instruction information in a general-purpose register.
This turns out to be quite cool to work with, because it allows
*runtime* (on-the-fly) configuration of the code being executed on AMX.

The goal of this post is to simply use the co-processor as efficiently as possible.
There are vector-vector instructions that will output vectors of the same length,
but those don't come anywhere close to saturating the compute capabilities of this chip.
Instead, we'll have to use an outer product to really get things going.

What is an outer product?
Assuming you have two input vectors $\mathbf{u}$ and $\mathbf{v}$:

$$\mathbf{u} = \begin{bmatrix} u_1 \\\\ u_2 \\\\ \vdots \\\\ u_m \end{bmatrix}, \quad \mathbf{v} = \begin{bmatrix} v_1 \\\\ v_2 \\\\ \vdots \\\\ v_n \end{bmatrix}$$

The outer product is the matrix containing a product of
every possible pairwise combination of their elements.
(This gives some hints as to why the Z register group
is so much bigger than X and Y.)

$$\mathbf{u} \otimes \mathbf{v} = \begin{bmatrix} u_1v_1 & \dots & u_1v_n \\\\ u_2v_1 & \dots & u_2v_n \\\\ \vdots & \ddots & \vdots \\\\ u_mv_1 & \dots & u_mv_n \end{bmatrix}$$
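As a point of reference, here is a plain scalar sketch of the outer product defined above. `outer_product` is only an illustrative name and is not part of any library. For fixed-size 16-float inputs, a single AMX instruction produces the same kind of result as this whole loop nest.

```cpp
#include <cassert>
#include <cstddef>

// Scalar outer product: out is an m x n row-major matrix with
// out[i * n + j] = u[i] * v[j], one product per pair of elements.
void outer_product(const float *u, std::size_t m,
                   const float *v, std::size_t n,
                   float *out) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      out[i * n + j] = u[i] * v[j];
    }
  }
}
```

Two inputs of length m and n produce m * n outputs. That is why the output register group has to be so much larger than the inputs.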

On the AMX co-processor, this boils down to a very simple instruction
that looks a lot like this:

![](https://i.imgur.com/yQa4cdq.png)

And there's a flag you can set to also make it accumulate from the
previous result:

![](https://i.imgur.com/MPsmwnX.png)

With this, we have all we need to write a matrix multiplication:
repeatedly load 16 floats from our input matrices and accumulate their
outer products into a 16x16 output.  The reduction dimension K doesn't even matter!
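To check that accumulating outer products really is a matrix multiplication, here is a scalar sketch that mirrors the loop we are about to write (the function name is hypothetical). At each step of k it reads 16 floats from each input and adds one outer product into a 16x16 tile. On the first step it overwrites the tile instead of accumulating.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of building a 16x16 tile from K outer products.
// Inputs hold 16 floats per step of k: A[k * 16 + m] and B[k * 16 + n].
// The tile is stored as C[n * 16 + m]. For K == 0 nothing is written.
void outer_product_matmul_16x16xK(const float *A, const float *B, float *C,
                                  std::uint64_t K) {
  for (std::uint64_t k = 0; k < K; ++k) {
    const float *u = A + k * 16;  // 16 floats from A for this k
    const float *v = B + k * 16;  // 16 floats from B for this k
    for (int n = 0; n < 16; ++n) {
      for (int m = 0; m < 16; ++m) {
        // first step overwrites, later steps accumulate
        const float prev = (k == 0) ? 0.0f : C[n * 16 + m];
        C[n * 16 + m] = prev + u[m] * v[n];
      }
    }
  }
}
```

The K loop only decides how many outer products get summed. Each step's work is fixed, which is why K barely affects the structure of the code.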

Let's simplify the problem and implicitly transpose the matrix multiplication.
Both A and B (our inputs) will have K (our reduction dimension)
as the *leading* dimension.  This doesn't really matter much in practice,
but it simplifies our code a lot.
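Your matrices may start out in the usual row-major layout, where each of A's 16 rows stores its K elements contiguously. In that case, a one-time repack puts K in the leading position. Here is a minimal sketch of such a repack (`pack_k_major` is a hypothetical helper and is not part of the original code):

```cpp
#include <cassert>
#include <cstddef>

// Repack a row-major rows x K matrix (src[r * K + k]) into a K-major
// layout (dst[k * rows + r]), so each step of k reads `rows` contiguous floats.
void pack_k_major(const float *src, float *dst, std::size_t rows, std::size_t K) {
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t k = 0; k < K; ++k) {
      dst[k * rows + r] = src[r * K + k];
    }
  }
}
```

For the 16x16 kernel, you'd call it with `rows = 16` for both A and B.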

Here's a reference that we can use to check our proposed solution:

```cpp
#include <stdint.h>

// C[n][m] = sum_k A[k][m] * B[k][n]; both inputs hold 16 floats per step of k
void reference_16x16xK(float *A, float *B, float *C, uint64_t K) {
  for (uint32_t m = 0; m < 16; ++m) {
    for (uint32_t n = 0; n < 16; ++n) {
      C[n * 16 + m] = 0;
      for (uint32_t k = 0; k < K; ++k) {
        C[n * 16 + m] += A[k * 16 + m] * B[k * 16 + n];
      }
    }
  }
}
```


And here is how we might do it in AMX:

```cpp
// Assumes corsix's aarch64.h is included and AMX_SET() has already been issued.
// Bit 27: overwrite Z instead of accumulating; only set for k == 0
uint64_t reset_z = 1ull << 27;

for (uint32_t k = 0; k < K; ++k) {
  // 64 bytes = 16 floats
  AMX_LDX((uint64_t)A + k * 64);
  AMX_LDY((uint64_t)B + k * 64);

  // a single 16x16 outer product, accumulated into Z
  AMX_FMA32(reset_z);
  reset_z = 0;
}
```


Interestingly, we didn't address *any* of the registers.
Secretly, we actually did.
In the same way `reset_z` is encoded as a bit flag,
register indices are *also* encoded in the arguments
passed to `AMX_*`.
The pointers to A and B only ever use up to 56 bits,
so Apple engineers stashed information in the other 8.
We simply left those bits at 0,
so we're using register 0 for both X and Y in this case.
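Here is a small sketch of that packing. `amx_operand` is a hypothetical helper and is not part of corsix's header. Like the code in this post, it assumes the address fits in the low 56 bits and the register index is shifted in at bit 56.

```cpp
#include <cstdint>

// Low 56 bits: the memory address.
constexpr std::uint64_t kAddrMask = (1ull << 56) - 1;

// Hypothetical helper: place a register index at bit 56, above the address.
constexpr std::uint64_t amx_operand(std::uint64_t addr, std::uint64_t reg) {
  return (reg << 56) | (addr & kAddrMask);
}
```

For example, `AMX_LDX(amx_operand((uint64_t)A, 2))` would target X register 2 instead of register 0. Passing a bare pointer is the `reg = 0` case.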

The code to store the Z registers
to memory is a bit more complicated because we
only populated the first column, which is every fourth Z register.
That means we need to grab registers 0, 4, 8, ..., 60:

```cpp
for (uint64_t i = 0; i < 16; ++i) {
  // the Z register index goes in the top bits, just like for loads
  const uint64_t z_register = (i * 4ull) << 56;
  AMX_STZ(z_register | ((uint64_t)C + i * 64));
}
```


Unfortunately, when you run the code above you'll find that it's super slow:
a paltry couple hundred GFlops.
Why?
Two reasons.

The first is a pipeline hazard.
Every `AMX_FMA32` depends on the previous one because we accumulate into
a single subset of the register file.
We end up hitting 25% of the Z registers at full throttle
and leave the rest idle, which prevents instruction-level parallelism.

The second is that we're underusing our loads.
We can load 128 bytes at once,
yet the code above only loads 64 bytes per instruction.
Loading more per instruction enables a bit more instruction-level parallelism as well.
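The same effect shows up in ordinary scalar code. The sketch below is a plain C++ analogy, not AMX code. A sum with one accumulator forms a single dependency chain, because every addition waits on the previous one. Splitting the sum across four independent accumulators lets the CPU keep several additions in flight, much like spreading FMAs across four Z accumulators.

```cpp
#include <cassert>
#include <cstddef>

// One accumulator: every addition depends on the previous one.
float sum_one_chain(const float *x, std::size_t n) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < n; ++i) acc += x[i];
  return acc;
}

// Four independent accumulators: four chains that can overlap in the
// pipeline, combined once at the end.
float sum_four_chains(const float *x, std::size_t n) {
  float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  std::size_t i = 0;
  for (; i + 4 <= n; i += 4) {
    acc[0] += x[i + 0];
    acc[1] += x[i + 1];
    acc[2] += x[i + 2];
    acc[3] += x[i + 3];
  }
  for (; i < n; ++i) acc[0] += x[i];  // leftover elements
  return (acc[0] + acc[1]) + (acc[2] + acc[3]);
}
```

Because floating-point addition isn't associative, the two versions can differ in the last bits for general inputs.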

So what's the plan?

![](https://i.imgur.com/4Aqubrs.gif)

We're going to load 128 bytes into X and Y
and then calculate a 32x32 block.
This involves 4 independent 16x16 outer products,
which should expose instruction-level parallelism
and use the loaded data more efficiently (each 64-byte half
of X and Y feeds two of the four FMAs).
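A quick back-of-the-envelope count shows why the bigger tile makes better use of loaded data. This sketch covers only the arithmetic for one step of k (the function name is made up). We load T floats from A and T floats from B, then perform T * T multiply-adds, counting each multiply-add as two flops.

```cpp
// Flops per byte loaded for one step of k with a T x T output tile:
// 2 * T * T flops for 2 * T floats (4 bytes each) loaded.
constexpr double flops_per_loaded_byte(int tile) {
  return (2.0 * tile * tile) / (2.0 * tile * 4.0);
}

static_assert(flops_per_loaded_byte(16) == 4.0, "16x16: 512 flops per 128 bytes");
static_assert(flops_per_loaded_byte(32) == 8.0, "32x32: 2048 flops per 256 bytes");
```

Doubling the tile edge doubles the work done per byte loaded.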

```cpp
void mm32x32xK(float* A, float* B, float* C, uint64_t K) {

  // flag to load/store 128 bytes (a pair of registers)
  const uint64_t load_store_2 = 1ull << 62;
  const uint64_t load_store_width = 128; // in bytes

  // only set for k == 0
  uint64_t reset_z = 1ull << 27;

  for (uint32_t k = 0; k < K; ++k) {
    // rotate through X/Y register pairs 0-1, 2-3, 4-5, 6-7
    uint64_t idx = k % 4;
    AMX_LDX(load_store_2 | (idx * 2) << 56 | (uint64_t)A + k * load_store_width);
    AMX_LDY(load_store_2 | (idx * 2) << 56 | (uint64_t)B + k * load_store_width);

    // offset into X and Y registers is byte-wise
    const uint64_t offset = idx * load_store_width;

    // now we do 4 independent outer products (avoiding pipeline hazards)
    // operand fields: Z accumulator (bit 20), Y byte offset (bit 10), X byte offset (bit 0)
    AMX_FMA32(reset_z | (0ull << 20) | ((offset +  0ull) << 10) | ((offset +  0ull) << 0));
    AMX_FMA32(reset_z | (1ull << 20) | ((offset + 64ull) << 10) | ((offset +  0ull) << 0));
    AMX_FMA32(reset_z | (2ull << 20) | ((offset +  0ull) << 10) | ((offset + 64ull) << 0));
    AMX_FMA32(reset_z | (3ull << 20) | ((offset + 64ull) << 10) | ((offset + 64ull) << 0));
    reset_z = 0;
  }

  for (uint64_t i = 0; i < 16; ++i) {
    // store interleaved
    AMX_STZ(load_store_2 | ((i * 4ull + 0) << 56) | (uint64_t)C + i * load_store_width);
    AMX_STZ(load_store_2 | ((i * 4ull + 2) << 56) | (uint64_t)C + (16 + i) * load_store_width);
  }
}
```


I put comments above, but there are some interesting details
related to the flags used for the instructions.
Corsix has done a great job of explaining this, so I'll leave the links here:
- load and store flags: https://github.com/corsix/amx/blob/main/ldst.md
- FMA flags: https://github.com/corsix/amx/blob/main/fma.md
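To make the bit layout easier to read, here is the same FMA encoding factored into a hypothetical helper. The bit positions are copied from `mm32x32xK` above. Corsix's FMA page explains what each field means and how wide it is.

```cpp
#include <cstdint>

// Hypothetical helper: assemble an AMX_FMA32 operand from the fields used in
// mm32x32xK. Z accumulator index at bit 20, Y byte offset at bit 10,
// X byte offset at bit 0, and the "overwrite Z" flag at bit 27.
constexpr std::uint64_t fma32_operand(std::uint64_t z, std::uint64_t y_offset,
                                      std::uint64_t x_offset, bool reset_z) {
  return (reset_z ? (1ull << 27) : 0ull) | (z << 20) | (y_offset << 10) |
         (x_offset << 0);
}
```

With it, the last FMA in the loop could be written as `AMX_FMA32(fma32_operand(3, offset + 64, offset + 64, k == 0));`.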

So how fast did we get it? Well, it kinda depends on K,
but we hit 1.5 TFlops for larger values 😁

![](https://i.imgur.com/paN25bd.png)
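If you want to reproduce a number like this, the sketch below shows how a TFlops figure follows from K, the number of calls, and the elapsed time. The helper names are hypothetical, and the timing itself is left to you. It assumes the usual convention of counting each multiply-add as two flops.

```cpp
#include <cstdint>

// FLOP count for one 32x32xK call: each of the 32 * 32 outputs
// receives K multiply-adds, i.e. 2 * 32 * 32 * K flops.
constexpr double flops_32x32xK(std::uint64_t K) { return 2.0 * 32 * 32 * K; }

// Throughput in TFlops for `calls` invocations that took `seconds` in total.
constexpr double tflops(std::uint64_t K, std::uint64_t calls, double seconds) {
  return flops_32x32xK(K) * static_cast<double>(calls) / seconds / 1e12;
}
```

For instance, 1000 calls with K = 1024 perform about 2.1 billion flops. Finishing them in about 2.1 ms therefore works out to roughly 1 TFlops.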

It should be no surprise that larger problems get better
relative performance:
there's more opportunity for the cache to warm up and for the CPU to
interleave instructions, and the fixed cost of storing Z at the end
is spread over more FMAs.
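Counting instructions makes the amortization concrete. The sketch below counts only the AMX instructions that `mm32x32xK` issues and ignores timing entirely, so treat it as a rough model. Each step of k issues one `AMX_LDX`, one `AMX_LDY` and four `AMX_FMA32`. The store loop at the end issues 32 `AMX_STZ`.

```cpp
#include <cstdint>

// Fraction of the AMX instructions in one mm32x32xK call that are FMAs:
// 4 FMAs out of 6 instructions per k, plus a fixed 32 stores at the end.
constexpr double fma_fraction(std::uint64_t K) {
  return (4.0 * K) / (6.0 * K + 32.0);
}
```

At K = 16, only half of the instructions are FMAs. As K grows, the fraction approaches 2/3.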

Overall, these problem sizes are microscopic in the context of
the large modern neural networks chasing general AI.
However, this type of performance opens the door for
smaller neural networks to find their place
in modern real-world computation.
If a prediction can run on a battery-powered
laptop in a couple dozen nanoseconds, there's likely a lot
of value that can be added in places
that might otherwise have used heuristics. What do you think?

I use [Twitter](https://twitter.com/bwasti) and [Mastodon](https://sigmoid.social/@bwasti)!