# Sizing Up Neural Nets


Neural networks clearly aren't getting any smaller.
Despite that, actually calculating the "size" of one is ambiguous
since there are so many different metrics one could use.
This ends up making cost planning
(on both the hardware and software side) somewhat tricky to discuss.

Below is a review of common metrics found in the literature
that should make it easier to think about neural network costs.
There's no "one size fits all" with metrics like this, so it's important
to be informed on the benefits and shortcomings of each approach.

## Number of Parameters

The most popular metric I've seen
(in papers) is to count the number of parameters being trained.
It's super easy to do!

```python
count = 0
for param in parameters:
    count += param.numel()
```


Unfortunately, except for rough comparisons between
identical architectures,
it's not very useful for fine-grained cost measurements.
Having more parameters doesn't necessarily mean you need more GPUs.

Consider matrix addition and matrix multiplication
with operands of the same size.  Assuming the right-hand side ($B$)
of both operations is a learned $1000\times 1000$ matrix:

$$C = A + B$$
$$C = A \times B$$

are these comparable in compute cost? Nope.

![](https://i.imgur.com/LDk7xxk.png)

That's a 57x difference in raw compute power used on my device.
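If you want to reproduce this kind of comparison yourself, here's a rough NumPy sketch. (The absolute numbers and the exact ratio will vary a lot with hardware and BLAS backend; the point is only that the gap is large despite identical parameter counts.)

```python
import time
import numpy as np

N = 1000
A = np.random.rand(N, N).astype(np.float32)
B = np.random.rand(N, N).astype(np.float32)

def bench(op, repeats=20):
    op(A, B)  # warm up
    start = time.perf_counter()
    for _ in range(repeats):
        op(A, B)
    return (time.perf_counter() - start) / repeats

t_add = bench(lambda A, B: A + B)  # ~N^2 flops
t_mul = bench(lambda A, B: A @ B)  # ~2*N^3 flops
print(f"matmul / add time ratio: {t_mul / t_add:.1f}x")
```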

## FLOPs

A popular alternative to parameter counts
in the literature is flop counts.
A "flop" is a floating point operation.
On its face, this looks like a pretty easy
metric to calculate.

```python
def foo(A, B):
    c = 0
    for i in range(100):
        c = max(A[i] * B[i], c)
    return c
```


How many flops is `foo(A, B)`? 200:
100 iterations of two floating point operations,
multiply and max.

#### Compilation

But what if we compile the above code using something
like `jax.jit`?

```python
foo = jax.jit(foo)
```


Well, `jax.jit` is going to make it run *faster* most likely,
but the total number of flops doesn't change at all.
At least, in the case of `foo`.

Consider bar below:

```python
def bar(a):
    c = 0
    for i in range(100):
        c = c + a
    return c
```

`bar(a)` clearly has 100 flops (additions),
but

```python
bar = jax.jit(bar)
```

will simplify this down into a single multiplication ($100 \cdot a$).
So, in reality, we have 1 flop!
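To see why the compiler can do this, note that the loop is just repeated addition of the same value. Here's a pure-Python before/after sketch (illustrative only; not what XLA literally emits):

```python
def bar_loop(a):
    c = 0
    for i in range(100):
        c = c + a  # 100 additions
    return c

def bar_folded(a):
    return 100 * a  # the whole loop collapses to one multiplication

assert bar_loop(3.0) == bar_folded(3.0)  # same result, 100 flops vs. 1
```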

Key takeaway, **flops must be calculated after compiler optimization**.

#### Interpretability

So, can we use flops as a metric for model size?
If we size a matrix addition and a matrix multiplication
to have the same number of flops,
will they run in the same amount of time? Nope.

![](https://i.imgur.com/NC50xIn.png)

Now, despite 2M flops each,
the matrix multiplication runs 39x faster than the pointwise addition.

## FLOPs + Arithmetic Intensity

How do we explain the new (reversed) gap in expected performance? Memory access.

Memory is really really slow.

![](https://i.imgur.com/AVlSgLz.gif)
*(You can play with this fun site [here](https://www.overbyte.com.au/misc/Lesson3/CacheFun.html).)*

Modern GPUs have much faster memory than is reflected in the above GIF,
but it's still no comparison to the speed of in-register computation.

This is a well understood problem with flop counts. A useful additional
metric (common in HPC papers) is arithmetic intensity $I$:

$$I = \frac{\text{FLOPs}}{\text{Bytes}}$$

Where $\text{Bytes}$ quantifies the total memory traffic of the operation.
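As a back-of-the-envelope example, assuming fp32 (4 bytes per element) and counting only the reads of $A$ and $B$ and the write of $C$:

```python
def intensity_matmul(n, bytes_per_elem=4):
    flops = 2 * n**3  # n^2 outputs, each a length-n dot product (mul + add)
    traffic = 3 * bytes_per_elem * n**2  # read A, read B, write C
    return flops / traffic  # = n / 6 for fp32

def intensity_add(n, bytes_per_elem=4):
    flops = n**2  # one add per element
    traffic = 3 * bytes_per_elem * n**2  # read A, read B, write C
    return flops / traffic  # = 1/12, independent of n

print(intensity_matmul(1000))  # ~166.7
print(intensity_add(1000))     # ~0.083
```

Matmul intensity grows with matrix size, while pointwise addition stays pinned near zero no matter how big the operands get.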

![](https://i.imgur.com/cE26OUS.png)

We see that pointwise addition isn't very "intense" at all, which somewhat explains
the measured performance gap.

There's of course a "peak" performance associated with any hardware
and amping up arithmetic intensity gets you closer to the peak.
Once you hit it, though, it doesn't really matter!

![](https://i.imgur.com/wNAMO9v.png)

On the CPU running this benchmark, that threshold is $I = 70$.

This chart may look familiar: it's a [roofline model](https://en.wikipedia.org/wiki/Roofline_model)!
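In code, the roofline is just the minimum of two ceilings: a bandwidth-scaled line below the ridge, and flat peak compute above it. The peak and bandwidth figures below are made up for illustration, chosen only so the ridge lands at the $I = 70$ threshold measured above:

```python
def roofline(intensity, peak_flops=2.8e12, bandwidth=40e9):
    """Attainable flops/s: memory-bound below the ridge, compute-bound above."""
    return min(peak_flops, intensity * bandwidth)

ridge = 2.8e12 / 40e9   # intensity where the two ceilings meet
print(ridge)            # 70.0
print(roofline(1))      # memory bound:  4e10 flops/s
print(roofline(500))    # compute bound: 2.8e12 flops/s
```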

Luckily **neural network architectures like large transformers have extremely high
arithmetic intensity** and exist primarily in the $\pi$ section of the above graphic.

## Modern Accelerators

Up until now the analysis has been pretty hand-wavy about the underlying hardware.
Let's compare workloads of identical number of parameters,
number of floating point operations
and arithmetic intensity.

$$\max(A \cdot B)$$

$$\sum(A \cdot B)$$

They should run in around the same amount of time right?

![](https://i.imgur.com/1LvoOjZ.png)

Let's see if we can figure out what's going on here.

The first thing to note is that the CPU backing this compute has
`fma` (fused multiply-add) instructions.
This gives the chip the ability to calculate

$$C = A \cdot B + C$$

in a single instruction running at a throughput of 2 instructions per cycle.
That's a 2x speedup over max + mul.

Importantly, FMA only uses 2 vector registers (the third operand
can be a memory address).  To compute both a max and mul we'll need
3 registers.  Since we have 16 total vector registers, that means we
can calculate FMAs with 8-way parallelism.  This is
necessary to achieve the published throughput of 2 instructions per cycle,
since an FMA has a latency of 4 cycles.
max+mul can only be calculated with 5-way instruction parallelism ($\lfloor{16/3}\rfloor$),
so we have another factor of $\frac{8}{5}$ speedup from using FMA.

That brings us to a total predicted difference of 3.2x, which is within 1% of the measured result!
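Written out, the estimate is just the product of the two factors above (the instruction-throughput 2x and the register-pressure 8/5):

```python
throughput_speedup = 2         # one fma vs. a separate mul + max
fma_regs, maxmul_regs = 2, 3   # vector registers per in-flight chain
total_regs = 16

# 8-way vs. 5-way instruction-level parallelism
parallelism_speedup = (total_regs // fma_regs) / (total_regs // maxmul_regs)

predicted = throughput_speedup * parallelism_speedup
print(predicted)  # 3.2
```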

What's the takeaway here?
**Not all flops are the same.**
In fact, depending on the *type* of floating point operations being used,
you can usually plot multiple different roofline models.
Here are the numbers on an A100 ([from Nvidia's site](https://www.nvidia.com/en-us/data-center/a100/)):

![](https://i.imgur.com/fSKQ289.jpg)

For FMA FP32 workloads, you'll only need an arithmetic intensity of
around 64 to hit "FMA peak".  This translates to square matrix multiplications
of size $400\times 400$.

However, to truly utilize an A100 (and get 8x more performance out of it),
you'll need to use TF32 tensor-cores.  The arithmetic intensity required
for "tensor-core peak"
is 512, or the equivalent of $3000 \times 3000$ matrix multiplications.
BF16 ends up having the same requirement, since the memory movement is reduced as well.
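You can sanity-check those matrix sizes from the intensity formula: for an fp32 $N \times N$ matmul, $I \approx 2N^3 / (3 \cdot 4N^2) = N/6$, so the minimum square size for a required intensity is roughly $6I$. A rough sketch (ignoring cache reuse, which lowers the real traffic):

```python
def min_square_matmul_size(required_intensity, bytes_per_elem=4):
    # I = 2N^3 / (3 * bytes * N^2)  =>  N = I * 3 * bytes / 2
    return required_intensity * 3 * bytes_per_elem / 2

print(min_square_matmul_size(64))   # 384  (~400x400 for "FMA peak")
print(min_square_matmul_size(512))  # 3072 (~3000x3000 for "tensor-core peak")
```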

## Conclusion

Modern hardware is not only getting faster but substantially more diverse.
Metrics used to determine your own needs should consider
both the models you're training/using and the characteristics
of the hardware you're planning to use.

I've found doing the math
to figure out "approximate required intensity" for hardware
and the "approximate intensity" of models is often quite instructive.