# Simplified Array Programming
<center>by <a href="https://twitter.com/bwasti" target="_blank">@bwasti</a></center>
Great strides have been made in optimizing compilers for machine learning.
Notable modern examples are [XLA](https://www.tensorflow.org/xla)
(the backend of [jax](https://jax.readthedocs.io/en/latest/)) and [TVM](https://tvm.apache.org).
Both yield state of the art performance across a variety of applications
and reduce the need for users to carefully consider the low level implications
of their model code.
The increased performance of language backends is great news
for frontend development, as it allows us to simplify the programming model.
Despite that, most frameworks use programming paradigms rooted in the
design of [NumPy](https://en.wikipedia.org/wiki/NumPy).
It's a great framework with an immense ecosystem, but it's from 2006 - predating
these incredible new compilers.
Embracing the cardinal sin of assuming a "[sufficiently smart compiler](https://retrocomputing.stackexchange.com/questions/17633/when-was-the-phrase-sufficiently-smart-compiler-first-used),"
how far can we go with frontend simplification?
### Named Tensors++
Named Tensors are an elegant proposal from Alexander Rush ([original](https://nlp.seas.harvard.edu/NamedTensor), [latest](https://namedtensor.github.io))
that enable users to name dimensions to manipulate tensors.
This is in stark contrast to position based manipulation,
which forces users to rely on comments or
tracking the flow of code to determine what's happening.
This idea was so well received that it was quickly added to PyTorch ([docs](https://pytorch.org/docs/stable/named_tensor.html)).
```
# pytorch
B = A / A.sum((0, 2)) # what is happening??
B = A / A.sum(("batch", "channel")) # ah, spatial dims remain the same
```
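To make the contrast concrete without named-tensor support, here is a minimal NumPy sketch. The helper `named_sum` and the dimension order `batch, y, channel, x` are assumptions made for illustration; neither comes from PyTorch or loop_tool.

```python
import numpy as np

# Hypothetical helper: reduce over dimensions given by name rather than position.
def named_sum(array, names, reduce_over):
    axes = tuple(names.index(n) for n in reduce_over)
    kept = [n for n in names if n not in reduce_over]
    return array.sum(axis=axes), kept

# Assumed layout for this example only.
names = ["batch", "y", "channel", "x"]
A = np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)

total, kept = named_sum(A, names, ["batch", "channel"])
# kept now lists the surviving (spatial) dimensions by name
```

The call site says which dimensions disappear, and the returned `kept` list documents what is left, so no comment is needed to explain the reduction.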
Equipped with an optimizing compiler,
we can associate the names with symbolic dimensions
and implement all our core primitives with them.
For example, we can implement strided windowed operations by
using our names as a symbolic dimension expression:
```
s = lt.SymbolGenerator()
# stride-2 1d convolution
A = B[s.x + 2 * s.k].sum(s.k)
```
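As a rough reference for what that index expression computes, the following NumPy sketch evaluates `A[x] = sum over k of B[x + 2 * k]` directly. The function name `windowed_sum`, the window size of 3, and the rule for the output length are assumptions for illustration; the snippet above leaves those sizes to the framework.

```python
import numpy as np

def windowed_sum(B, window, step):
    """out[x] = sum over k in range(window) of B[x + step * k]."""
    # Largest index read is (n_out - 1) + step * (window - 1)
    n_out = len(B) - step * (window - 1)
    x = np.arange(n_out)[:, None]   # output index, as a column
    k = np.arange(window)[None, :]  # window index, as a row
    return B[x + step * k].sum(axis=1)

B = np.arange(10.0)
A = windowed_sum(B, window=3, step=2)
# A[x] = B[x] + B[x + 2] + B[x + 4]
```

The index arithmetic `x + step * k` is the whole operation; everything else is bookkeeping that a symbolic frontend can derive from the expression.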
### Symbolic/Lazy Tensors -> Symbolic/Lazy Dimensions
Another popular idea in machine learning frameworks is
symbolic tensors,
which accumulate execution graphs rather than eagerly
execute operations.
Lazy execution enables more advanced
analysis and optimization behind the scenes.
```
# mxnet
import mxnet as mx
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = 2 * a + b
fn = c.bind(mx.cpu(), {'a': mx.nd.array([1,2]), 'b':mx.nd.array([2,3])})
# ^ optimizations happen here
out = fn.forward()
```
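The mechanism behind symbolic tensors can be sketched in a few lines of plain Python. The `Lazy` class below is a toy written for this illustration and is not how mxnet or PyTorch implement tracing: operators only record graph nodes, and work happens when `evaluate` is called.

```python
class Lazy:
    """A node in an expression graph; nothing is computed until evaluate()."""

    def __init__(self, op, args=(), name=None):
        self.op, self.args, self.name = op, args, name

    @staticmethod
    def wrap(value):
        return value if isinstance(value, Lazy) else Lazy("const", (value,))

    def __add__(self, other):
        return Lazy("add", (self, Lazy.wrap(other)))

    def __mul__(self, other):
        return Lazy("mul", (self, Lazy.wrap(other)))

    # Addition and multiplication commute for the numbers used here
    __radd__ = __add__
    __rmul__ = __mul__

    def evaluate(self, env):
        if self.op == "var":
            return env[self.name]
        if self.op == "const":
            return self.args[0]
        lhs, rhs = (arg.evaluate(env) for arg in self.args)
        return lhs + rhs if self.op == "add" else lhs * rhs


def variable(name):
    return Lazy("var", name=name)


a, b = variable("a"), variable("b")
c = 2 * a + b                      # builds a graph, computes nothing
result = c.evaluate({"a": 1, "b": 2})
```

Between building `c` and calling `evaluate`, the full graph is available for inspection, which is the window in which a real framework would apply its optimizations.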
Symbolic tensors can be found in both
[mxnet](https://mxnet.apache.org/versions/1.4.1/api/python/symbol/symbol.html)
and [PyTorch](https://pytorch.org/docs/stable/fx.html).
We can take the idea of extracting information from traces a step further
by tracing through the implementations of the operations themselves (and their dimensions):
```
def matrix_vector_multiply(X, W):
    a, b = X.symbolic_shape  # X is 2D: rows a, columns b
    return (X * W.to(b)).sum(b)
```
This symbolic trace gives us more information (information the compiler would
need to derive anyway), including the expected input shape (based on the output).
```
def test_backward_shape_inference():
    X = lt.Tensor(lt.Symbol("a"), lt.Symbol("b"))
    W = lt.Tensor(128)
    Y = matrix_vector_multiply(X, W)
    Y.set_size(64)
    Y.unify()
    # since we expose the impl, this is easily derived by the compiler
    assert X.shape == [64, 128]
```
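The inference in that test amounts to binding each symbol to a size wherever one is known and then reading the bindings back for every tensor. Here is a minimal pure-Python sketch of that idea; the function `unify` and its dictionary-based inputs are assumptions for illustration and say nothing about how loop_tool implements `unify()` internally.

```python
def unify(symbolic_shapes, concrete_sizes):
    """Bind symbols to sizes, then resolve every tensor's shape.

    symbolic_shapes: tensor name -> list of dimension symbols (strings here)
    concrete_sizes:  tensor name -> list of known sizes, one per dimension
    """
    bindings = {}
    for tensor, sizes in concrete_sizes.items():
        for symbol, size in zip(symbolic_shapes[tensor], sizes):
            # setdefault returns the earlier binding if one exists
            if bindings.setdefault(symbol, size) != size:
                raise ValueError(f"conflicting sizes for {symbol}")
    # Symbols that were never bound resolve to None
    return {
        tensor: [bindings.get(symbol) for symbol in dims]
        for tensor, dims in symbolic_shapes.items()
    }


# X has dims (a, b); W is viewed along b; Y = sum over b keeps only a.
symbolic_shapes = {"X": ["a", "b"], "W": ["b"], "Y": ["a"]}
concrete_sizes = {"W": [128], "Y": [64]}
shapes = unify(symbolic_shapes, concrete_sizes)
```

Because the output `Y` shares the symbol `a` with `X`, a size set on the output flows back to the input without any shape function written per operator.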
### Embedded vmap
`vmap` is a popular feature of [jax](https://jax.readthedocs.io/en/latest/jax.html#jax.vmap).
It injects dimensions into programs (rather than tensors) and is a good way to
get around the dimensional limits of already written NumPy style APIs:
```
import jax.numpy as jnp
from jax import vmap

def gen_jax_mm():
    vv = lambda x, y: jnp.vdot(x, y) # numpy style dot product
    mv = vmap(vv, (0, None), 0) # vmapped to matrix vector
    mm = vmap(mv, (None, 1), 1) # vmapped to matrix matrix
    return mm
```
If we define all our operations symbolically we can make them
general enough to handle any future inputs without a need to use a vmap primitive.
```
def tc(X, W): # tensor contraction
    reduction_dims = set(X.symbolic_shape) & set(W.symbolic_shape)
    Y = (X * W).sum(*reduction_dims)
    return Y

def test_tensor_contractions():
    s = lt.SymbolGenerator()
    assert tc(lt.Tensor(s.k), lt.Tensor(s.k)).symbolic_shape == []
    assert tc(lt.Tensor(s.k), lt.Tensor(s.m, s.k)).symbolic_shape == [s.m]
    assert tc(lt.Tensor(s.m, s.k), lt.Tensor(s.n, s.k)).symbolic_shape == [s.m, s.n]
```
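The same "reduce over whatever names are shared" rule can be approximated in NumPy by generating an `einsum` specification from lists of dimension names. This is a sketch under assumptions: `contract` is a hypothetical helper, names are plain strings, and at most 26 distinct names are supported because each one is mapped to a single letter.

```python
import numpy as np

def contract(x, x_dims, w, w_dims):
    """Multiply and sum over every dimension name the two operands share."""
    shared = [d for d in x_dims if d in w_dims]
    out_dims = [d for d in x_dims + w_dims if d not in shared]
    # Give each distinct name a single-letter einsum label
    unique = dict.fromkeys(x_dims + w_dims)
    letters = {d: chr(ord("a") + i) for i, d in enumerate(unique)}
    spec = "{},{}->{}".format(
        "".join(letters[d] for d in x_dims),
        "".join(letters[d] for d in w_dims),
        "".join(letters[d] for d in out_dims),
    )
    return np.einsum(spec, x, w), out_dims

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))
W = rng.standard_normal((5, 4))
v = rng.standard_normal(4)

dot, dot_dims = contract(v, ["k"], v, ["k"])            # vector . vector
mv, mv_dims = contract(v, ["k"], X, ["m", "k"])         # matrix . vector
mm, mm_dims = contract(X, ["m", "k"], W, ["n", "k"])    # matrix . matrix
```

One definition covers all three cases, which is the property the `vmap` version has to build up by wrapping.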
### No more "convXd"
It's a small point, but
[many](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv1D)
[frameworks](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html)
[have](https://mxnet.apache.org/versions/1.1.0/api/python/gluon/nn.html#mxnet.gluon.nn.Conv1D)
[dimension](https://github.com/geohot/tinygrad/blob/master/tinygrad/nn.py#L34)
[specific](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve2d.html)
[convolution](https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/Conv1D_en.html)
[operations](https://numpy.org/doc/stable/reference/generated/numpy.convolve.html).
I believe this is largely an artifact of specific kernel implementations, but it could
be a matter of convenience (eliding the need to specify the window shape).
With an optimizing compiler and named tensors, dimensionality is a lot more flexible.
Users can pass in whatever they'd like as long as they highlight
which dimensions are to be convolved.
```
# 1d, 2d, 5d convs - whatever
def conv(X, W, spatial, window):
    assert len(spatial) == len(window)
    # output dimensions need new names
    new_spatial = [lt.Symbol(x.name + "o") for x in spatial]
    outer = [d for d in X.symbolic_shape if d not in spatial]
    exprs = [x + k for x, k in zip(new_spatial, window)]
    X = X.to(*outer, *new_spatial, *window, constraints=zip(spatial, exprs))
    # reduce over input channels and the windowed dims
    reduction_dims = (set(X.symbolic_shape) & set(W.symbolic_shape)) | set(window)
    return (X * W).sum(*reduction_dims)

def test_2d():
    s = lt.SymbolGenerator()
    X = lt.Tensor(s.b, s.ic, s.y, s.x)
    W = lt.Tensor(s.oc, s.ic, s.wy, s.wx)
    # 2d convolution
    Y = conv(X, W, [s.y, s.x], [s.wy, s.wx])
```
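For comparison, a dimension-agnostic convolution can also be sketched in NumPy with `sliding_window_view` and `tensordot`. Unlike the named version above, this sketch assumes a fixed positional layout of `(batch, in_channels, *spatial)` for the input and `(out_channels, in_channels, *window)` for the weights; `conv_nd` is a hypothetical name, and it computes an unpadded, stride-1 cross-correlation only.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv_nd(x, w):
    """Unpadded, stride-1 cross-correlation over any number of spatial dims.

    x: (batch, in_channels, *spatial)
    w: (out_channels, in_channels, *window)
    """
    assert w.ndim == x.ndim
    n_spatial = x.ndim - 2
    spatial_axes = tuple(range(2, x.ndim))
    # windows: (batch, in_channels, *out_spatial, *window)
    windows = sliding_window_view(x, w.shape[2:], axis=spatial_axes)
    # Reduce over in_channels and every window dimension
    x_axes = [1] + list(range(2 + n_spatial, 2 + 2 * n_spatial))
    w_axes = [1] + list(range(2, 2 + n_spatial))
    out = np.tensordot(windows, w, axes=(x_axes, w_axes))
    # tensordot yields (batch, *out_spatial, out_channels)
    return np.moveaxis(out, -1, 1)

# 1d: a window of two ones sums adjacent elements
y1 = conv_nd(np.arange(1.0, 6.0).reshape(1, 1, 5), np.ones((1, 1, 2)))
# 2d: same function, no code changes
y2 = conv_nd(np.ones((2, 3, 6, 5)), np.ones((4, 3, 3, 2)))
```

The axis arithmetic needed here (`x_axes`, `w_axes`, `moveaxis`) is exactly the positional bookkeeping that the named formulation removes.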
### No more NCHW/NHWC
Another artifact of handwritten kernels is the hardcoded (and often global) layouts of tensors.
cuDNN used to [default to NCHW](https://twitter.com/cHHillee/status/1472693287857262592) while TF
defaulted to NHWC. PyTorch defaults to NCHW but has support for NHWC.
Intel uses [nChw8c](https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html#blocked-layout).
It's a mess, for sure, but users should be empathetic to the kernel developers.
It's extremely hard to get away from this model without a good optimizing compiler.
Luckily, this post assumes that exists, so we can enjoy things like this:
```
def test_compiler():
    s = lt.SymbolGenerator()
    W = lt.Tensor(s.channel_out, s.ky, s.channel_in, s.kx)
    X = lt.Tensor(s.x, s.channel_in, s.y, s.batch)
    return X, W, conv(X, W, [s.x, s.y], [s.kx, s.ky])

def test():
    X, W, Y = test_compiler()
    X.set_size(128, 3, 333, 7)
    W.set_size(1024, 5, 3, 4)
    print(Y.code) # ship it!
```
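A small NumPy sketch shows why addressing dimensions by name makes user code independent of layout. The helper `to_layout` and the name lists are assumptions for illustration; the point is only that the reduction is written once in terms of names and gives the same answer for either storage order.

```python
import numpy as np

def to_layout(array, dims, target):
    """Reorder axes so they follow the target list of dimension names."""
    return np.transpose(array, [dims.index(d) for d in target])

rng = np.random.default_rng(0)
nchw = ["batch", "channel", "y", "x"]
nhwc = ["batch", "y", "x", "channel"]

X_nchw = rng.standard_normal((2, 3, 4, 5))
X_nhwc = to_layout(X_nchw, nchw, nhwc)

# Spatial mean, with axes looked up by name for each layout
mean_nchw = X_nchw.mean(axis=(nchw.index("y"), nchw.index("x")))
mean_nhwc = X_nhwc.mean(axis=(nhwc.index("y"), nhwc.index("x")))
# Both results have dims (batch, channel)
```

In NumPy the name lookup is manual; in a frontend with named dimensions the layout becomes a detail the compiler is free to choose.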
## Try this out
If you'd like to try these ideas out,
the frontend has been added to (a pre-release of) [loop_tool](https://github.com/facebookresearch/loop_tool).
The performance is still a work in progress, but one can expect
comparable default performance to NumPy (not quite PyTorch/TF yet).
You can either `pip install loop_tool_py` and copy some of the examples above
or use this [notebook](https://colab.research.google.com/drive/1ceDDJeZ9-uHyd0ZDH5UZlZrQXS8sAp_i?usp=sharing).