# `as_strided` and `sum` are all you need

*...to implement the non-pointwise operations in a neural network.*

<center>by <a href="https://twitter.com/bwasti" target="_blank">@bwasti</a></center>

Luckily, the popular frameworks PyTorch and Numpy both provide `as_strided` and `sum`
as primitives, so this can be demonstrated with runnable code.

## `sum`

`sum` is a generalization of the familiar Python function,
which is an additive reduction over a list.


```python
import torch

sum([1, 2, 3]) == torch.tensor([1, 2, 3]).sum()  # tensor(True)
```
On an N-dimensional tensor,
`sum` lets you specify which dimensions to reduce over.

```python
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0569, -0.2475,  0.0737, -0.3429],
        [-0.2993,  0.9138,  0.9337, -1.6864],
        [ 0.1132,  0.7892, -0.1003,  0.5688],
        [ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)
tensor([-0.4598, -0.1381,  1.3708, -2.6217])
```
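
Later examples will also reduce over several dimensions in one call by passing a tuple; a quick sketch of that form (my own, not from the original examples):

```python
import torch

# Passing a tuple of dimensions collapses all of them in a single reduction.
b = torch.randn(2, 3, 4)
print(torch.sum(b, (1, 2)).shape)  # torch.Size([2])
```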

## `as_strided`

`as_strided` is trickier to explain.
It lets users reinterpret an input tensor by specifying custom
sizes and strides for the dimensions of the output, along with a storage offset.

PyTorch's docs give a great example:

```python
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.9039,  0.6291,  1.0795],
        [ 0.1586,  2.1939, -0.4900],
        [-0.1909, -0.7503,  1.9355]])
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t
tensor([[0.9039, 1.0795],
        [0.6291, 0.1586]])
>>> t = torch.as_strided(x, (2, 2), (1, 2), 1)
>>> t
tensor([[0.6291, 0.1586],
        [1.0795, 2.1939]])
```

A 3 by 3 tensor is reinterpreted as a 2 by 2 view whose elements are pulled
straight out of the underlying storage by the given sizes, strides, and offset.
Clearly this is a powerful and dangerous operation.
It's *so* dangerous that both PyTorch and Numpy have explicit warnings
in their documentation:

![](https://i.imgur.com/Vdw9ddm.png)
![](https://i.imgur.com/NMwfuaa.png)

That means we're dealing with the good stuff.
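
To make the reinterpretation concrete, here's a small sanity check (my own, not from the docs) reproducing the second call above: element $(i, j)$ of the view is read from the flat storage at `offset + i * stride[0] + j * stride[1]`.

```python
import torch

x = torch.randn(3, 3)
t = torch.as_strided(x, (2, 2), (1, 2), 1)  # size (2, 2), strides (1, 2), offset 1

flat = x.flatten()
for i in range(2):
    for j in range(2):
        # Each element of the view is a flat read at offset + i*1 + j*2.
        assert t[i, j] == flat[1 + i * 1 + j * 2]
```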

## Matrix trace

Although matrix traces rarely find themselves in neural networks,
they're good for reasoning about `as_strided`.

A trace is a sum over the diagonal of a square matrix:
$$
\mathbf{A} =
  \begin{pmatrix}
    a_{11} & a_{12} & a_{13} \\
    a_{21} & a_{22} & a_{23} \\
    a_{31} & a_{32} & a_{33}
  \end{pmatrix}
$$

$$
\operatorname{tr}(\mathbf{A}) = \sum_{i=1}^{3} a_{ii} = a_{11} + a_{22} + a_{33}
$$

```python
m = torch.randn(16, 16)
tr_m = torch.trace(m)
```
Implemented as a for loop:

```python
tr_m = 0
for x in range(16):
  idx = x * 16 + x
  tr_m += m.flatten()[idx]
```
The key operation is the indexing math, which we can split out as a single
equation for a single dimension, writing $|x|$ for the extent of the loop (16 here):
$$
index = x \cdot |x| + x
$$

Generally, once the index equation is isolated, we can ditch the for loop
and use `as_strided`.
The trick is to differentiate the index equation with respect to the loop variable
($x$ in this case) and use the result as the stride.

$$
\frac{\partial index}{\partial x}
   = |x| + 1
$$

```python
n = m.clone()
tr_n = torch.as_strided(n, (16,), (16 + 1,)).sum()
torch.testing.assert_allclose(tr_m, tr_n)
```

Cool, right? Maybe not yet. It gets better, I promise.

## Outer product

Now, we can jump into two dimensional index equations.
The outer product is another rarely employed operation in neural networks,
but its implementation demonstrates how strides can be used
for broadcasting to higher dimensions.

$$
\mathbf{u} = \begin{bmatrix} u_1 \\ u_2 \\ \vdots \\ u_m \end{bmatrix},
\quad
\mathbf{v} = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_n \end{bmatrix}
$$

$$
\mathbf{u} \otimes \mathbf{v} = \begin{bmatrix}
    u_1v_1 & u_1v_2 & \dots & u_1v_n \\
    u_2v_1 & u_2v_2 & \dots & u_2v_n \\
    \vdots & \vdots & \ddots & \vdots \\
    u_mv_1 & u_mv_2 & \dots & u_mv_n
  \end{bmatrix}
$$

As a for loop:

```python
# u and v are length-16 vectors
outer = torch.zeros(16, 16)
for i in range(16):
  for j in range(16):
    outer[i, j] = u[i] * v[j]
```

It's a bit easier to see the index equations for `u` and `v` once you
write both loop variables into each index explicitly:

```python
for i in range(16):
  for j in range(16):
    outer[i,j] = u[i + 0 * j] * v[0 * i + j]
```

So our index equations are

$$
idx_u = i + 0 \cdot j,
\quad
idx_v = 0 \cdot i + j
$$

Differentiate to find the strides of $idx_u$

$$
\frac{\partial idx_u}{\partial i} = 1,
\quad
\frac{\partial idx_u}{\partial j} = 0
$$

and do the same for $idx_v$ to get this `as_strided` implementation:

```python
u = torch.randn(16)
v = torch.randn(16)

outer_0 = torch.outer(u, v)
outer_1 = torch.as_strided(u, (16, 16), (1, 0)) * torch.as_strided(v, (16, 16), (0, 1))
torch.testing.assert_allclose(outer_0, outer_1)
```
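
A side note worth verifying (my own check, not from the original examples): the zero strides mean the broadcasted view of `u` repeats data without copying it.

```python
import torch

u = torch.randn(16)
u_bcast = torch.as_strided(u, (16, 16), (1, 0))

# Same storage as u: no memory was copied to build the 16x16 view.
print(u_bcast.data_ptr() == u.data_ptr())            # True
print(bool((u_bcast[:, 0] == u_bcast[:, 5]).all()))  # True: every column is just u
```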

## The real stuff

Finally we can talk about the real stuff: matrix multiplications and convolutions.

### Matrix multiplication


In machine learning, matrix multiplications are found in many places
(transformers, MLPs) and go by many names (dense, linear, dot, tensor contraction).
A matrix multiplication is a straightforward extension of the previous two examples:

```python
# A, B, and C are flattened 16x16 matrices; C starts out as zeros
for m in range(16):
  for k in range(16):
    for n in range(16):
      C[m * 16 + n] += A[m * 16 + k + n * 0] * B[m * 0 + k * 16 + n]
```

$$
idx_a = 16 \cdot m + k + 0 \cdot n
$$

$$
\frac{\partial idx_a}{\partial m} = 16,
\quad
\frac{\partial idx_a}{\partial k} = 1,
\quad
\frac{\partial idx_a}{\partial n} = 0
$$
(not shown, but the same idea holds for $idx_b$)
```python
a = torch.randn(16, 16)
b = torch.randn(16, 16)

mm_0 = torch.mm(a, b)
mm_1 = (torch.as_strided(a, (16, 16, 16), (16, 1, 0)) * torch.as_strided(b, (16, 16, 16), (0, 16, 1))).sum(1)

torch.testing.assert_allclose(mm_0, mm_1)
```

We can see that it's effectively a 3d outer product followed by a summation over the shared $k$ axis.
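
One caveat (my own note, not from the original): the strided views themselves cost nothing, but the elementwise product materializes the full 16×16×16 intermediate before the reduction, which is a big part of why this formulation is slow in practice.

```python
import torch

a = torch.randn(16, 16)
b = torch.randn(16, 16)

a3 = torch.as_strided(a, (16, 16, 16), (16, 1, 0))  # view, no copy
b3 = torch.as_strided(b, (16, 16, 16), (0, 16, 1))  # view, no copy
prod = a3 * b3            # a fresh tensor with 16*16*16 elements is allocated here
print(prod.shape)         # torch.Size([16, 16, 16])
print(prod.sum(1).shape)  # torch.Size([16, 16])
```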


### Convolution

Here's a cool one.  Pretty much any convolution you'd want
can be implemented with `as_strided` + `sum`.
Even crazy ones like strided,
asymmetrically padded, dilated, grouped 3d convolutions (left as an exercise to the reader).

To review, a convolution is a sweeping operation that aggregates
windows of values.  The graphic below shows a 2d convolution without strides or dilation.


![](https://miro.medium.com/max/1400/1*Fw-ehcNBR9byHtho-Rxbtw.gif)
([image source](https://towardsdatascience.com/intuitively-understanding-convolutions-for-deep-learning-1f6f42faee1))


Before getting into how 2d convolutions work, let's start with one dimension.
With a size-3 kernel of ones, it can be written in a loop like so:

```python
# X is a length-16 input; Y is a length-14 output, initialized to zeros
Y = torch.zeros(14)
for i in range(14):
  for k in range(3):
    Y[i] += X[i + k]
```

The index equation is

$$
index = i + k
$$

so,

$$
\frac{\partial index}{\partial i} = 1,
\quad
\frac{\partial index}{\partial k} = 1
$$

On its face, this doesn't seem to have enough information to implement the operation:
there's nothing about the window of size 3 sweeping over the input.

That information is contained in the output shape passed to `as_strided`, (14, 3).
Much like the matrix multiplication, we're broadcasting into a higher-dimensional
tensor and then reducing over the newly created axis.

```python
X = torch.randn(16)

conv_0 = torch.as_strided(X, (14, 3), (1, 1)).sum(axis=1)
conv_1 = torch.nn.functional.conv1d(X.reshape(1,1,16), torch.ones(1,1,3)).flatten()

torch.testing.assert_allclose(conv_0, conv_1)
```
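
To back up the earlier claim about dilated convolutions: dilation only changes the stride of the kernel axis. A sketch under the same ones-kernel setup (dilation 2, so 12 output positions from 16 inputs):

```python
import torch

X = torch.randn(16)

# Window (i, k) reads X[i + 2 * k]: stride 1 over output positions, stride 2 over taps.
dconv_0 = torch.as_strided(X, (12, 3), (1, 2)).sum(axis=1)
dconv_1 = torch.nn.functional.conv1d(
    X.reshape(1, 1, 16), torch.ones(1, 1, 3), dilation=2
).flatten()

torch.testing.assert_allclose(dconv_0, dconv_1)
```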

Extending to a two dimensional, strided-by-2 convolution,
we get the index equations

$$
idx_x = 2 \cdot x + k_x,
\quad
idx_y = 2 \cdot y + k_y
$$

For $idx_x$ we get

$$
\frac{\partial idx_x}{\partial x} = 2,
\quad
\frac{\partial idx_x}{\partial k_x} = 1
$$

We can keep this 2d by multiplying each derived stride by the original stride
of the input axis it indexes.
Since $x$ indexes the leading dimension, which already has a stride of $|y|$ (16 here),
its derived strides of 2 and 1 become 32 and 16, while the $y$ strides stay 2 and 1.

```python
X = torch.randn(16, 16)

sconv2d_0 = torch.as_strided(X, (7, 7, 3, 3), (2 * 16, 2, 16, 1)).sum(axis=(2,3))
sconv2d_1 = torch.nn.functional.conv2d(X.reshape(1,1,16,16), torch.ones(1,1,3,3), stride=2).reshape(7,7)

torch.testing.assert_allclose(sconv2d_0, sconv2d_1)
```
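
The examples so far use kernels of ones so that the aggregation is a plain `sum`. With real weights, the only change (a sketch extending the block above, not part of the original) is an elementwise multiply against the windowed view before the reduction:

```python
import torch

X = torch.randn(16, 16)
W = torch.randn(3, 3)

# Same strided windows as above; broadcasting multiplies each 3x3 window by W.
windows = torch.as_strided(X, (7, 7, 3, 3), (2 * 16, 2, 16, 1))
wconv2d_0 = (windows * W).sum(axis=(2, 3))
wconv2d_1 = torch.nn.functional.conv2d(
    X.reshape(1, 1, 16, 16), W.reshape(1, 1, 3, 3), stride=2
).reshape(7, 7)

torch.testing.assert_allclose(wconv2d_0, wconv2d_1)
```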

Interestingly, Tinygrad uses this exact approach for its
[convolution implementation](https://github.com/geohot/tinygrad/blob/ad756f611279a74e7133ea54f4243c9dc3313078/tinygrad/ops_cpu.py#L167-L171):

```python
tx = np.lib.stride_tricks.as_strided(gx,
  shape=(bs, ctx.groups, cin, oy, ox, H, W),
  strides=(*gx.strides[0:3], gx.strides[3]*ys, gx.strides[4]*xs, *gx.strides[3:5]),
  writeable=False,
)
```


## Framework Support

Although many operations can be expressed this way, there aren't any frameworks
that directly optimize for the `as_strided` + `sum` pair of operations.
As a result, the performance can be quite poor.

If you're interested in the approach,
I've been working on the experimental framework [`loop_tool`](https://github.com/facebookresearch/loop_tool)
that really runs with this idea.
The API is symbolic and the differentiation is done for you,
but performance is still a work in progress. (`pip install loop_tool_py`)

Here's a 1d convolution in `loop_tool`:
```python
x_ik = x.to(i, k, constraints=[(idx, i + k)])
y = (x_ik * w).sum(k)
```
(An interactive tutorial can be found on [colab](https://colab.research.google.com/gist/bwasti/d88f345f5a106ca935bb55aa1baf4924/loop_tool_demo.ipynb))

If you're not interested in using a framework,
the underlying loops with strided access are not too painful.
For example, a stride-3, dilation-2 convolution in C++:

```cpp
// Each output C[i] aggregates K taps: stride 3 over the input, dilation 2 across the kernel.
for (auto i = 0; i < I; ++i) {
  for (auto k = 0; k < K; ++k) {
    C[i] += X[i * 3 + k * 2] * W[i * 0 + k * 1];
  }
}
```
When optimizing by hand, just be sure to keep in mind which axes are reductions.
To reorder reductions (and induce parallelism), you can check out Halide's
[rfactor tutorial](https://halide-lang.org/tutorials/tutorial_lesson_18_parallel_associative_reductions.html).