# Sizing Up Neural Nets

*by [@bwasti](https://twitter.com/bwasti) ([mastodon](https://sigmoid.social/@bwasti))*


Neural networks clearly aren't getting any smaller.
Despite that, actually calculating the "size" of one is ambiguous
since there are so many different metrics one could use. 
This ends up making cost planning
(on both the hardware and software side) somewhat tricky to discuss.

Below is a review of common metrics
found in the literature that should make it easier to think about neural network costs
for your own use cases.
There's no "one size fits all" with metrics like this, so it's important
to be informed about the benefits and shortcomings of various approaches.

## Number of Parameters

The most popular metric I've seen
(in papers) is to count the number of parameters being trained.
It's super easy to do!

```python
# `parameters`: any iterable of tensors, e.g. model.parameters() in PyTorch
count = 0
for param in parameters:
  count += param.numel()
```

Unfortunately, except for rough comparisons between
identical architectures,
it's not very useful for fine-grained cost measurements.
More parameters don't necessarily mean you need more GPUs.

Consider matrix addition and matrix multiplication
with the same learned operand.  Assume the right-hand side ($B$) of both operations
is a learned matrix of size $1000 \times 1000$:

$$C = A + B$$

$$C = A \times B$$

Are these comparable in compute cost? Nope.


That's a 57x difference in raw compute power used on my device.
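As a rough illustration (a sketch, assuming $A$ is also $1000 \times 1000$, which isn't stated above, and counting each multiply-add as 2 flops), both expressions share the same 1M learned parameters in $B$ but do wildly different amounts of arithmetic:

```python
def param_count(rows, cols):
    # number of learned scalars in a rows x cols weight matrix
    return rows * cols


def add_flops(n, m):
    # C = A + B: one addition per output element
    return n * m


def matmul_flops(n, k, m):
    # C = A @ B with A (n x k) and B (k x m): each of the n*m outputs
    # is a length-k dot product, conventionally counted as 2k flops
    return 2 * n * k * m


N = 1000
print(param_count(N, N))      # 1000000 parameters in B for either op
print(add_flops(N, N))        # 1000000
print(matmul_flops(N, N, N))  # 2000000000
```

The flop ratio here (2000x) need not match a measured runtime ratio, since memory traffic matters too.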

## FLOPs

A popular alternative to parameter counts
in the literature is flop counts.
A "flop" is a floating point operation.
On its face, this looks like a pretty easy
metric to calculate.

```python
import jax.numpy as jnp

def foo(A, B):
  c = 0.0
  for i in range(100):
    c = jnp.maximum(A[i] * B[i], c)  # one multiply + one max per iteration
  return c
```

How many flops does `foo(A, B)` perform? 200:
100 iterations of two floating point operations,
a multiply and a max.
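One way to sanity-check a count like this is to instrument the program and tally each scalar operation as it executes. The sketch below uses a hypothetical `FlopCounter` helper (not a real library) and threads it through a copy of `foo`'s loop; it counts the operations exactly as written, before any compiler gets a chance to rewrite them.

```python
class FlopCounter:
    """Hypothetical helper: tallies scalar floating point ops as they run."""

    def __init__(self):
        self.count = 0

    def mul(self, x, y):
        self.count += 1
        return x * y

    def max(self, x, y):
        self.count += 1
        return x if x > y else y


def foo_counted(A, B, ops):
    # same loop as foo, but every flop goes through the counter
    c = 0.0
    for i in range(100):
        c = ops.max(ops.mul(A[i], B[i]), c)
    return c


ops = FlopCounter()
result = foo_counted([1.0] * 100, [2.0] * 100, ops)
print(result, ops.count)  # 2.0 200
```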

#### Compilation

But what if we compile the above code using something
like `jax.jit`?

```python
import jax

foo = jax.jit(foo)
```

Well, `jax.jit` will most likely make it run *faster*,
but the total number of flops doesn't change at all.
At least, not in the case of `foo`.

Consider `bar` below:

```python
def bar(a):
  c = 0
  for i in range(100):
    c = c + a
  return c
```

`bar(a)` clearly has 100 flops (additions), but

```python
bar = jax.jit(bar)
```

will simplify this down into a single multiplication.
So, in reality, we have 1 flop!

Key takeaway: **flops must be calculated after compiler optimization**.
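To make the takeaway concrete, here is a hand-written sketch of the rewrite described for `bar`: the chain of 100 additions as written versus the single multiplication it reduces to. Both functions are illustrative stand-ins that return a value alongside a manual flop tally; they are not what `jax.jit` literally emits.

```python
def bar_as_written(a):
    # 100 dependent additions, as in the source program
    flops = 0
    c = 0.0
    for _ in range(100):
        c = c + a
        flops += 1
    return c, flops


def bar_simplified(a):
    # the algebraically equivalent form: one multiplication
    return 100 * a, 1


naive_value, naive_flops = bar_as_written(0.5)
opt_value, opt_flops = bar_simplified(0.5)
print(naive_value, naive_flops)  # 50.0 100
print(opt_value, opt_flops)      # 50.0 1
```

With 0.5 both forms agree exactly; for arbitrary floats, 100 rounded additions need not equal one rounded multiply, so a rewrite like this isn't bit-exact in general.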

#### Interpretability

So, can we use flops as a metric for model size?
Assuming we adjust the two workloads (addition vs multiplication) 
to have the same number of flops,
will they run in the same amount of time? Nope.


Now, despite 2M flops each,
the matrix multiplication runs 39x faster than the
addition on my device!

## FLOPs + Arithmetic Intensity

How do we explain the new (reversed) gap in expected performance? Memory access.

Memory is really really slow.

*(You can play with this fun site [here](https://www.overbyte.com.au/misc/Lesson3/CacheFun.html).)*

Modern GPUs have much faster memory than the above GIF reflects,
but it's still nowhere near the speed of in-register computation.

This is a well understood problem with flop counts. A useful additional
metric (common in HPC papers) is arithmetic intensity $I$:

$$I = \frac{FLOPs}{Bytes}$$

Where $Bytes$ quantifies the total memory traffic of the operation.


We see that pointwise addition isn't very "intense" at all, which somewhat explains
the measured performance gap.
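A minimal sketch of the calculation, assuming FP32 (4-byte) elements and that each input is read from memory once and each output written once (ideal caching). The shapes are one hypothetical choice that yields the 2M flops each used in the comparison above: a pointwise addition over 2M elements and a $100 \times 100$ matrix multiplication.

```python
BYTES_PER_FP32 = 4


def add_intensity(num_elements, bytes_per_elem=BYTES_PER_FP32):
    # C = A + B: one flop per element; read A and B, write C
    flops = num_elements
    traffic = 3 * num_elements * bytes_per_elem
    return flops / traffic


def matmul_intensity(n, bytes_per_elem=BYTES_PER_FP32):
    # square n x n matmul: 2n^3 flops; read A and B, write C (3 n^2 elements)
    flops = 2 * n**3
    traffic = 3 * n**2 * bytes_per_elem
    return flops / traffic


print(round(add_intensity(2_000_000), 3))  # 0.083 flops/byte
print(round(matmul_intensity(100), 1))     # 16.7 flops/byte
```

Under these assumptions, addition's intensity is a constant $\frac{1}{12}$ regardless of size, while a square matmul's grows linearly as $\frac{n}{6}$.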

There's of course a "peak" performance associated with any hardware,
and amping up arithmetic intensity gets you closer to that peak.
Once you hit it, though, further increases in intensity don't buy you anything!


On the CPU running this benchmark, that threshold is $I = 70$.

This chart may look familiar: it's a [roofline model](https://en.wikipedia.org/wiki/Roofline_model)!
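In its basic form, the roofline model caps attainable performance at $P = \min(\pi, \beta \cdot I)$, where $\pi$ is peak compute and $\beta$ is peak memory bandwidth, so the threshold (the "ridge point") sits at $I = \pi / \beta$. A small sketch; the hardware numbers below are made up for illustration and are not measurements of the CPU above.

```python
def attainable_flops(intensity, peak_flops, bandwidth_bytes):
    # roofline: memory-bound below the ridge, compute-bound above it
    return min(peak_flops, bandwidth_bytes * intensity)


def ridge_point(peak_flops, bandwidth_bytes):
    # intensity (flops/byte) where the two roofs meet
    return peak_flops / bandwidth_bytes


# hypothetical machine: 1 TFLOP/s peak compute, 10 GB/s memory bandwidth
PEAK = 1e12
BW = 1e10
print(ridge_point(PEAK, BW))  # 100.0
for I in (1, 10, 100, 500):
    print(I, attainable_flops(I, PEAK, BW))
```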


Luckily **neural network architectures like large transformers have extremely high
arithmetic intensity** and exist primarily in the $\pi$ section of the above graphic.
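As a rough check on that claim, here is a sketch estimating the intensity of a single dense (linear) layer applied to a batch of tokens, assuming each input and weight is read from memory once and each output written once. The shapes (2048 tokens, a 4096-to-16384 projection) and 2-byte BF16 elements are hypothetical, picked to resemble a large transformer's feed-forward block.

```python
def linear_intensity(tokens, d_in, d_out, bytes_per_elem=2):
    # Y = X @ W with X (tokens x d_in) and W (d_in x d_out)
    flops = 2 * tokens * d_in * d_out
    # read X and W once, write Y once
    elems = tokens * d_in + d_in * d_out + tokens * d_out
    return flops / (elems * bytes_per_elem)


print(round(linear_intensity(2048, 4096, 16384)))  # 1260
print(round(linear_intensity(1, 4096, 16384), 2))  # 1.0
```

The high intensity comes from pushing many tokens through the same weights; with a single token, the same layer drops to roughly 1 flop per byte.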

## Modern Accelerators

Up until now the analysis has been pretty hand-wavy about the underlying hardware.
Let's compare workloads with identical parameter counts,
flop counts,
and arithmetic intensity:

$$\max(A \cdot B)$$

$$\sum(A \cdot B)$$

They should run in around the same amount of time, right?


Let's see if we can figure out what's going on here.

The first thing to note is that the CPU backing this compute has
`fma` instructions.
This gives the chip the ability to calculate

$$C = A \cdot B + C$$

in a single instruction running at a throughput of 2 instructions per cycle.
That's a 2x speedup over `max` + `mul`.

Importantly, FMA only uses 2 vector registers (the third operand
can be a memory address).  To compute both a `max` and `mul` we'll need
3 registers.  Since we have 16 total vector registers, that means we
can calculate FMAs with 8-way parallelism.  This is
necessary to achieve the published throughput of 2 instructions per cycle,
since an FMA has a latency of 4 cycles.
`max` + `mul` can only be calculated with 5-way instruction parallelism ($\lfloor{16/3}\rfloor$),
so we have another factor of $\frac{8}{5}$ speedup from using FMA.

That brings us to a total predicted difference of 3.2x, which is within 1% of the measured result!
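The estimate above is simple enough to write out as a few lines of arithmetic. This only transcribes the numbers already given (16 vector registers, 4-cycle FMA latency, 2 FMAs per cycle); it is not a general performance model.

```python
VECTOR_REGISTERS = 16
FMA_LATENCY_CYCLES = 4
FMA_PER_CYCLE = 2

# independent FMAs that must be in flight to hide latency at full throughput
required_parallelism = FMA_LATENCY_CYCLES * FMA_PER_CYCLE  # 8

fma_parallelism = VECTOR_REGISTERS // 2      # 8 chains, 2 registers each
max_mul_parallelism = VECTOR_REGISTERS // 3  # 5 chains, 3 registers each

instruction_factor = 2  # one fma replaces a mul plus a max
parallelism_factor = fma_parallelism / max_mul_parallelism  # 8 / 5

print(required_parallelism, fma_parallelism)    # 8 8
print(instruction_factor * parallelism_factor)  # 3.2
```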

What's the takeaway here?
**Not all flops are the same.**
In fact, depending on the *type* of floating point operations being used,
you can usually plot multiple different roofline models.
Here are the numbers on an A100 ([from Nvidia's site](https://www.nvidia.com/en-us/data-center/a100/)):


For FMA FP32 workloads, you'll only need an arithmetic intensity of
around 64 to hit "FMA peak".  This translates to square matrix multiplications
of size $400\times 400$.

However, to truly utilize an A100 (and get 8x more performance out of it),
you'll need to use TF32 tensor-cores.  The arithmetic intensity required 
for "tensor-core peak"
is 512, or the equivalent of $3000 \times 3000$ matrix multiplications.
BF16 ends up having the same requirement, since the memory movement is reduced as well.
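The conversion from required intensity to matrix size can be sketched as follows, assuming each matrix is read or written exactly once: a square $n \times n$ multiply has intensity $\frac{2n^3}{3n^2 b} = \frac{2n}{3b}$ for $b$-byte elements, so $n = \frac{3bI}{2}$. TF32 operands are stored as 32-bit values, so $b = 4$ for both FP32 and TF32. The BF16 line assumes, for this sketch, that its tensor-core peak is 2x TF32's at the same memory bandwidth, which doubles the required intensity.

```python
def square_matmul_size(required_intensity, bytes_per_elem):
    # invert I(n) = 2n^3 / (3 n^2 b) = 2n / (3b)
    return 3 * bytes_per_elem * required_intensity / 2


print(square_matmul_size(64, 4))    # 384.0   FP32 FMA peak, ~400 x 400
print(square_matmul_size(512, 4))   # 3072.0  TF32 tensor cores, ~3000 x 3000
print(square_matmul_size(1024, 2))  # 3072.0  BF16: 2x the flops, half the bytes
```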

## Conclusion

Modern hardware is not only getting faster but substantially more diverse.
Metrics used to determine your own needs should consider
both the models you're training/using and the characteristics
of the hardware you're planning to use.

I've found doing the math
to figure out "approximate required intensity" for hardware
and the "approximate intensity" of models is often quite instructive.

Thanks for reading!

-- [@bwasti](http://twitter.com/bwasti) ([mastodon](https://sigmoid.social/@bwasti))