as_strided and sum are all you need

...to implement the non-pointwise operations in a neural network.

by @bwasti

Luckily, the popular frameworks PyTorch and NumPy both provide as_strided and sum as primitives, so all of this can be demonstrated with runnable code.

sum

sum is a generalization of the familiar Python function, which is an additive reduction over a list.

sum([1,2,3]) == torch.tensor([1,2,3]).sum()

On an N-dimensional tensor, sum lets you specify which dimensions to reduce over.

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0569, -0.2475,  0.0737, -0.3429],
        [-0.2993,  0.9138,  0.9337, -1.6864],
        [ 0.1132,  0.7892, -0.1003,  0.5688],
        [ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)
tensor([-0.4598, -0.1381,  1.3708, -2.6217])
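
sum can also reduce over several dimensions at once by passing a tuple of dims, which the 2d convolution at the end of this post relies on:

>>> b = torch.randn(4, 3, 2)
>>> torch.sum(b, (1, 2)).shape
torch.Size([4])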

as_strided

as_strided is trickier to explain. It lets you reinterpret an input tensor by specifying custom sizes and strides for the dimensions of the output, along with a storage offset.

PyTorch's docs give a great example:

>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.9039,  0.6291,  1.0795],
        [ 0.1586,  2.1939, -0.4900],
        [-0.1909, -0.7503,  1.9355]])
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t
tensor([[0.9039, 1.0795],
        [0.6291, 0.1586]])
>>> t = torch.as_strided(x, (2, 2), (1, 2), 1)
>>> t
tensor([[0.6291, 0.1586],
        [1.0795, 2.1939]])

A 3 by 3 tensor is reinterpreted as a 2 by 2 tensor and everything is shifted around. Clearly this is a powerful and dangerous operation, so dangerous that both PyTorch and NumPy carry explicit warnings about it in their documentation.

That means we're dealing with the good stuff.
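
Since NumPy will show up again later, here's a minimal sketch of the first reinterpretation above using np.lib.stride_tricks.as_strided; the one gotcha is that NumPy strides are specified in bytes rather than elements.

import numpy as np

x = np.arange(9, dtype=np.float32).reshape(3, 3)
# shape (2, 2) with element strides (1, 2), scaled to bytes by the item size
t = np.lib.stride_tricks.as_strided(
  x, shape=(2, 2), strides=(1 * x.itemsize, 2 * x.itemsize))
# t is [[0., 2.],
#       [1., 3.]]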

Matrix trace

Although matrix traces rarely find themselves in neural networks, they're good for reasoning about as_strided.

A trace is a sum over the diagonal of a square matrix: $$ \mathbf{A} = \begin{pmatrix} a_{11} & a_{12} & a_{13} \\a_{21} & a_{22} & a_{23} \\a_{31} & a_{32} & a_{33} \end{pmatrix} $$

$$ \operatorname{tr}(\mathbf{A}) = \sum_{i=1}^{3} a_{ii} = a_{11} + a_{22} + a_{33} $$

m = torch.randn(16, 16)
tr_m = torch.trace(m)

Implemented as a for loop:

tr_loop = 0
for x in range(16):
  # diagonal entries of the flattened 16 x 16 matrix sit at index x * 16 + x
  idx = x * 16 + x
  tr_loop += m.flatten()[idx]
torch.testing.assert_allclose(tr_m, tr_loop)

The key operation is the indexing math, which we can split out as a single equation for a single dimension, writing $|x|$ for the size of the dimension that $x$ iterates over (16 here): $$ index = x \cdot |x| + x $$

Generally, once the index equation is isolated, we can ditch the for loop and use as_strided. The trick is to differentiate with respect to the loop variable ($x$ in this case), and set that as the stride.

$$ \frac{\partial index}{\partial x} = |x| + 1 $$

n = m.clone()
tr_n = torch.as_strided(n, (16,), (16 + 1,)).sum()
torch.testing.assert_allclose(tr_m, tr_n)

Cool, right? Maybe not yet. It gets better, I promise.

Outer product

Now, we can jump into two dimensional index equations. The outer product is another rarely employed operation in neural networks, but its implementation demonstrates how strides can be used for broadcasting to higher dimensions.

$$ \mathbf{u} =\begin{bmatrix} u_1 \\u_2 \\ \vdots \\u_m \end{bmatrix}, \quad \mathbf{v} = \begin{bmatrix} v_1 \\v_2 \\ \vdots \\ v_n \end{bmatrix} $$

$$ \mathbf{u} \otimes \mathbf{v} = \begin{bmatrix} u_1v_1 & u_1v_2 & \dots & u_1v_n \\ u_2v_1 & u_2v_2 & \dots & u_2v_n \\ \vdots & \vdots & \ddots & \vdots \\ u_mv_1 & u_mv_2 & \dots & u_mv_n \end{bmatrix} $$

As a for loop:

for i in range(16):
  for j in range(16):
    outer[i,j] = u[i] * v[j]

It's a bit easier to see the index equations for u and v once you add in the variables of the loops:

for i in range(16):
  for j in range(16):
    outer[i,j] = u[i + 0 * j] * v[0 * i + j]

So our index equations are

$$ idx_u = i + 0 \cdot j, \quad idx_v = 0 \cdot i + j $$

Differentiate to find the strides of $idx_u$

$$ \frac{\partial idx_u}{\partial i} = 1, \quad \frac{\partial idx_u}{\partial j} = 0 $$

and do the same for $idx_v$ to get this as_strided implementation:

u = torch.randn(16)
v = torch.randn(16)

outer_0 = torch.outer(u, v)
outer_1 = torch.as_strided(u, (16, 16), (1, 0)) * torch.as_strided(v, (16, 16), (0, 1))
torch.testing.assert_allclose(outer_0, outer_1)

The real stuff

Finally we can talk about the real stuff: matrix multiplications and convolutions.

Matrix multiplication

In machine learning, matrix multiplications are found in many places (transformers, MLPs) and go by many names (dense, linear, dot, tensor contraction). A matrix multiplication is a straightforward extension of the previous two examples:

for m in range(16):
  for k in range(16):
    for n in range(16):
      C[m * 16 + n] += A[m * 16 + k + n * 0] * B[m * 0 + k * 16 + n]

$$ idx_a = 16 \cdot m + k + 0 \cdot n $$

$$ \frac{\partial idx_a}{\partial m} = 16, \quad \frac{\partial idx_a}{\partial k} = 1, \quad \frac{\partial idx_a}{\partial n} = 0 $$ (not shown, but the same idea holds for $idx_b$)

a = torch.randn(16, 16)
b = torch.randn(16, 16)

mm_0 = torch.mm(a, b)
mm_1 = (torch.as_strided(a, (16, 16, 16), (16, 1, 0)) * torch.as_strided(b, (16, 16, 16), (0, 16, 1))).sum(1)

torch.testing.assert_allclose(mm_0, mm_1)

We can see that it's effectively a 3d outer product followed by a summation.
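
For intuition, the same broadcast-then-reduce structure can also be written with ordinary broadcasting instead of zero strides; a quick sketch reusing a, b, and mm_0 from above:

# (m, k, 1) * (1, k, n) broadcasts to (m, k, n), then reduce over k
mm_2 = (a[:, :, None] * b[None, :, :]).sum(1)
torch.testing.assert_allclose(mm_0, mm_2)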

Convolution

Here's a cool one. Pretty much any convolution you'd want can be implemented with as_strided + sum. Even crazy ones like strided, asymmetrically padded, dilated, grouped 3d convolutions (left as an exercise to the reader).

To review, a convolution is a sweeping operation that aggregates windows of values: a small window slides across the input and the values it covers are reduced (for now, with no strides or dilation).

Before getting into how 2d convolutions work, let's start with one dimension. It can be written in a loop like so:

for i in range(14):
  for k in range(3):
    Y[i] += X[i + k]

The index equation is

$$ index = i + k $$

so,

$$ \frac{\partial index}{\partial i} = 1, \quad \frac{\partial index}{\partial k} = 1 $$

On its face, this doesn't seem to have enough information to implement the operation. There's nothing about the window of size 3 sweeping over the input.

That information is contained in the output shape passed to as_strided: (14, 3). Much like the matrix multiplication, we're broadcasting into a higher dimensional tensor and then reducing over the newly created axis.

X = torch.randn(16)

conv_0 = torch.as_strided(X, (14, 3), (1, 1)).sum(axis=1)
conv_1 = torch.nn.functional.conv1d(X.reshape(1,1,16), torch.ones(1,1,3)).flatten()

torch.testing.assert_allclose(conv_0, conv_1)

Extending to a two dimensional, strided-by-2 convolution, we get the index equations

$$ idx_x = 2 \cdot x + k_x, \quad idx_y = 2 \cdot y + k_y $$

For $idx_x$ we get

$$ \frac{\partial idx_x}{\partial x} = 2, \quad \frac{\partial idx_x}{\partial k_x} = 1 $$

Because the input is itself two dimensional, each derived stride then gets multiplied by the underlying element stride of the axis it indexes. $x$ and $k_x$ walk the leading (row) dimension, which has a stride of $|y| = 16$, so their derived strides of 2 and 1 become 32 and 16; $y$ and $k_y$ walk the contiguous dimension, so theirs stay 2 and 1.

X = torch.randn(16, 16)

sconv2d_0 = torch.as_strided(X, (7, 7, 3, 3), (2 * 16, 2, 16, 1)).sum(axis=(2,3))
sconv2d_1 = torch.nn.functional.conv2d(X.reshape(1,1,16,16), torch.ones(1,1,3,3), stride=2).reshape(7,7)

torch.testing.assert_allclose(sconv2d_0, sconv2d_1)
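
The fancier variants mentioned above follow the same recipe. As one example, here's a sketch of a dilation-2 version of the earlier 1d convolution (still with an implicit all-ones kernel): the index equation becomes $i + 2 \cdot k$, so the strides become (1, 2) and the output shrinks to 12.

X = torch.randn(16)

dconv_0 = torch.as_strided(X, (12, 3), (1, 2)).sum(axis=1)
dconv_1 = torch.nn.functional.conv1d(X.reshape(1,1,16), torch.ones(1,1,3), dilation=2).flatten()

torch.testing.assert_allclose(dconv_0, dconv_1)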

Interestingly, Tinygrad uses this exact approach for its convolution implementation:

tx = np.lib.stride_tricks.as_strided(gx,
  shape=(bs, ctx.groups, cin, oy, ox, H, W),
  strides=(*gx.strides[0:3], gx.strides[3]*ys, gx.strides[4]*xs, *gx.strides[3:5]),
  writeable=False,
)

Framework support

Although many operations can be expressed this way, there aren't any frameworks that directly optimize for the as_strided + sum pair of operations. As a result, the performance can be quite poor.
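
For a rough sense of the gap, a quick and unscientific timing sketch like the one below is enough; exact numbers depend on your machine and build, but since the as_strided + sum matrix multiplication materializes and reduces a large (m, k, n) intermediate, expect it to lose badly to torch.mm:

import time
import torch

a = torch.randn(256, 256)
b = torch.randn(256, 256)

def mm_strided(a, b):
  # broadcast to (m, k, n) views, multiply, then reduce over k
  return (torch.as_strided(a, (256, 256, 256), (256, 1, 0))
          * torch.as_strided(b, (256, 256, 256), (0, 256, 1))).sum(1)

for name, fn in [("torch.mm", lambda: torch.mm(a, b)),
                 ("as_strided + sum", lambda: mm_strided(a, b))]:
  start = time.perf_counter()
  for _ in range(10):
    fn()
  print(name, time.perf_counter() - start)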

If you're interested in the approach, I've been working on the experimental framework loop_tool that really runs with this idea. The API is symbolic and the differentiation is done for you, but performance is still a work in progress. (pip install loop_tool_py)

Here's a 1d convolution in loop_tool:

x_ik = x.to(i, k, constraints=[(idx, i + k)])
y = (x_ik * w).sum(k)

(An interactive tutorial can be found on colab)

If you're not interested in using a framework, the underlying loops with strided access are not too painful to write by hand. For example, a stride-3, dilation-2 convolution in C++:

for (auto i = 0; i < I; ++i) {
  for (auto k = 0; k < K; ++k) {
    // index = i * 3 + k * 2 into X: a stride of 3 and a dilation of 2
    C[i] += X[i * 3 + k * 2] * W[i * 0 + k * 1];
  }
}
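
To sanity check that access pattern against a framework, here's a sketch of the same loop back in PyTorch, assuming an input of length 16 and a kernel of length 3 (which gives an output of length 4):

X = torch.randn(16)
W = torch.randn(3)

# index = i * 3 + k * 2, so the strides are (3, 2)
out_0 = (torch.as_strided(X, (4, 3), (3, 2)) * W).sum(axis=1)
out_1 = torch.nn.functional.conv1d(X.reshape(1,1,16), W.reshape(1,1,3), stride=3, dilation=2).flatten()

torch.testing.assert_allclose(out_0, out_1)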

When optimizing by hand, just be sure to keep in mind which axes are reductions. To reorder over reductions (induce parallelism), you can check out Halide's rfactor tutorial.