
Sizing Up Neural Nets

by @bwasti (mastodon)


Neural networks clearly aren't getting any smaller. Despite that, actually calculating the "size" of one is ambiguous since there are so many different metrics one could use. This ends up making cost planning (on both the hardware and software side) somewhat tricky to discuss.

Below is a review of common metrics found in the literature that should make it easier to think about neural network costs for your own use cases. There's no "one size fits all" with metrics like this, so it's important to be informed on the benefits and shortcomings of various approaches.

Number of Parameters

The most popular metric I've seen (in papers) is to count the number of parameters being trained. It's super easy to do!

count = 0
for param in parameters:
  count += param.numel()  # numel() is the number of elements in the tensor
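
For instance, in PyTorch the same count collapses to a one-liner. (The model below is an arbitrary stand-in, just to make the snippet runnable.)

import torch.nn as nn

# An arbitrary stand-in model for illustration.
model = nn.Sequential(nn.Linear(1000, 1000), nn.ReLU(), nn.Linear(1000, 10))

# Sum the element counts of every trainable tensor.
count = sum(p.numel() for p in model.parameters())
print(count)  # 1,011,010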

Unfortunately, except for rough comparisons between identical architectures, it's not very useful for fine-grained cost measurements. More parameters doesn't mean you need more GPUs.

Consider matrix addition and matrix multiplication of the same size. Assuming the right-hand side (B) of both operations is a learned matrix of size 1000x1000:

$C = A + B$

$C = A \times B$

are these comparable in compute cost? Nope.

That's a 57x difference in raw compute power used on my device.
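
For a sense of the raw arithmetic involved (separate from the measured throughput gap): with $N = 1000$,

$\text{flops}_{\text{add}} = N^2 = 10^6 \qquad \text{flops}_{\text{matmul}} \approx 2N^3 = 2 \times 10^9$

so both operations touch the same learned parameters, but do wildly different amounts of work.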

FLOPs

A popular alternative to parameter counts in the literature is flop counts. A "flop" is a floating point operation. On its face, this looks like a pretty easy metric to calculate.

def foo(A, B):
  c = 0
  for i in range(100):
    c = max(A[i] * B[i], c)
  return c

How many flops is foo(A, B)? 200: 100 iterations of two floating point operations each, a multiply and a max.

Compilation

But what if we compile the above code using something like jax.jit?

foo = jax.jit(foo)

jax.jit will most likely make it run faster, but the total number of flops doesn't change at all. At least, not in the case of foo.

Consider bar below:

def bar(a):
  c = 0
  for i in range(100):
    c = c + a
  return c

bar(a) clearly has 100 flops (additions), but

bar = jax.jit(bar)

will simplify this down into a single multiplication. So, in reality, we have 1 flop!

Key takeaway: flops must be counted after compiler optimization.
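
One way to do that in practice (a rough sketch; cost_analysis() is backend-dependent and its exact return format varies across JAX versions) is to ask XLA for its own estimate of the compiled program's cost, rather than counting operations in the Python source:

import jax
import jax.numpy as jnp

def bar(a):
  c = 0.0
  for i in range(100):
    c = c + a
  return c

# Lower and compile for a concrete input, then query XLA's cost model.
# The reported flops reflect the optimized program, not the 100 additions
# written in the Python source.
compiled = jax.jit(bar).lower(jnp.float32(1.0)).compile()
print(compiled.cost_analysis())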

Interpretability

So, can we use flops as a metric for model size? Assuming we adjust the two workloads (addition vs multiplication) to have the same number of flops, will they run in the same amount of time? Nope.

Now, despite 2M flops each, the matrix multiplication runs 39x faster than the addition on my device!
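
Here is roughly how such a comparison can be set up (a sketch, assuming JAX; the sizes below are picked to give ~2M flops each, and the ratio you measure will depend on your device):

import time
import jax
import jax.numpy as jnp

# Two workloads with ~2M flops each:
#   elementwise add over 2,000,000 elements -> 2e6 flops
#   100x100 matrix multiplication           -> 2 * 100^3 = 2e6 flops
a = jnp.ones((2_000_000,), dtype=jnp.float32)
b = jnp.ones((2_000_000,), dtype=jnp.float32)
x = jnp.ones((100, 100), dtype=jnp.float32)
y = jnp.ones((100, 100), dtype=jnp.float32)

add = jax.jit(lambda u, v: u + v)
mm = jax.jit(lambda u, v: u @ v)

def bench(f, *args, iters=100):
  f(*args).block_until_ready()  # warm up (triggers compilation)
  start = time.perf_counter()
  for _ in range(iters):
    out = f(*args)
  out.block_until_ready()
  return (time.perf_counter() - start) / iters

print("add   :", bench(add, a, b))
print("matmul:", bench(mm, x, y))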

FLOPs + Arithmetic Intensity

How do we explain the new (reversed) gap in expected performance? Memory access.

Memory is really really slow.

(You can play with this fun site here.)

Modern GPUs have much faster memory than is reflected in the above GIF, but it's still no comparison to the speed of in-register computation.

This is a well understood problem with flop counts. A useful additional metric (common in HPC papers) is arithmetic intensity I:

$I = \frac{\text{FLOPs}}{\text{Bytes}}$

Where Bytes quantifies the total memory traffic of the operation.
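
As a worked example (assuming FP32, i.e. 4 bytes per element, and that each operand is read or written exactly once): pointwise addition of two $N \times N$ matrices does $N^2$ flops while moving about $12N^2$ bytes, and an $N \times N$ matrix multiplication does about $2N^3$ flops over the same $12N^2$ bytes:

$I_{\text{add}} = \frac{N^2}{12N^2} \approx 0.08 \qquad I_{\text{matmul}} = \frac{2N^3}{12N^2} = \frac{N}{6}$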

We see that pointwise addition isn't very "intense" at all, which somewhat explains the measured performance gap.

There's of course a "peak" performance associated with any hardware and amping up arithmetic intensity gets you closer to the peak. Once you hit it, though, it doesn't really matter!

On the CPU running this benchmark, that threshold is $I = 70$.

This chart may look familiar: it's a roofline model!
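
In the usual roofline notation (standard, not specific to this post), attainable performance $P$ is bounded by the peak compute rate $\pi$ and by the peak memory bandwidth $\beta$ scaled by the arithmetic intensity $I$:

$P = \min(\pi,\ \beta \cdot I)$

Below the threshold ($I = 70$ here) the $\beta \cdot I$ term wins and the workload is memory-bound; above it, you're pinned at $\pi$.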

Luckily, neural network architectures like large transformers have extremely high arithmetic intensity and live primarily in the flat, compute-bound $\pi$ section of the roofline.

Modern Accelerators

Up until now the analysis has been pretty hand-wavy about the underlying hardware. Let's compare two workloads with an identical number of parameters, an identical number of floating point operations, and identical arithmetic intensity:

$C = \max(A \cdot B,\ C)$

$C = A \cdot B + C$

They should run in around the same amount of time, right? On my device they don't: the mul + add version comes out more than 3x faster than the mul + max one.

Let's see if we can figure out what's going on here.

The first thing to note is that the CPU backing this compute has fma instructions. This gives the chip the ability to calculate

$C = A \cdot B + C$

in a single instruction running at a throughput of 2 instructions per cycle. That's a 2x speedup over max + mul.

Importantly, FMA only uses 2 vector registers (the third operand can be a memory address). To compute both a max and a mul we'll need 3 registers. Since we have 16 total vector registers, that means we can calculate FMAs with 8-way parallelism. This is necessary to achieve the published throughput of 2 instructions per cycle, since an FMA has a latency of 4 cycles. max + mul can only be calculated with 5-way instruction parallelism (16/3, rounded down), so we get another factor of 8/5 = 1.6x speedup from using FMA.

That brings us to a total predicted difference of 2 × 1.6 = 3.2x, which is within 1% of the measured result!

What's the takeaway here? Not all flops are the same. In fact, depending on the type of floating point operations being used, you can usually plot multiple different roofline models. Here are the numbers on an A100 (from Nvidia's site):

For FMA FP32 workloads, you'll only need an arithmetic intensity of around 64 to hit "FMA peak". This translates to square matrix multiplications of size 400×400.

However, to truly utilize an A100 (and get 8x more performance out of it), you'll need to use TF32 tensor-cores. The arithmetic intensity required for "tensor-core peak" is 512, or the equivalent of 3000×3000 matrix multiplications. BF16 ends up having the same requirement, since the memory movement is reduced as well.
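
Those matrix sizes follow from the intensity estimate above ($I \approx N/6$ for an FP32 $N \times N$ matmul, assuming each operand is moved once):

$\frac{N}{6} = 64 \Rightarrow N \approx 384 \qquad \frac{N}{6} = 512 \Rightarrow N \approx 3072$

which is where the roughly 400×400 and 3000×3000 figures come from.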

Conclusion

Modern hardware is not only getting faster but substantially more diverse. Metrics used to determine your own needs should consider both the models you're training/using and the characteristics of the hardware you're planning to use.

I've found doing the math to figure out "approximate required intensity" for hardware and the "approximate intensity" of models is often quite instructive.

Thanks for reading!

-- @bwasti (mastodon)