as_strided and sum are all you need
...to implement the non-pointwise operations in a neural network.
Luckily, the popular frameworks PyTorch and NumPy both have as_strided and sum primitives, so this can be demonstrated with runnable code.
sum
sum is a generalization of the familiar Python function, which is an additive reduction over a list.
sum([1,2,3]) == torch.tensor([1,2,3]).sum()
On an N-dimensional tensor, sum lets you specify which dimensions to reduce over.
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0569, -0.2475,  0.0737, -0.3429],
        [-0.2993,  0.9138,  0.9337, -1.6864],
        [ 0.1132,  0.7892, -0.1003,  0.5688],
        [ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)
tensor([-0.4598, -0.1381, 1.3708, -2.6217])
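sum also accepts a tuple of dimensions, which the convolution examples below will rely on. A quick illustrative snippet (not from the PyTorch docs):
>>> torch.sum(torch.ones(2, 3, 4), (1, 2))
tensor([12., 12.])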
as_strided
as_strided is trickier to explain. It allows users to reinterpret input tensors by specifying custom sizes, strides, and offsets for the dimensions of the output.
PyTorch's docs give a great example:
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.9039,  0.6291,  1.0795],
        [ 0.1586,  2.1939, -0.4900],
        [-0.1909, -0.7503,  1.9355]])
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t
tensor([[0.9039, 1.0795],
        [0.6291, 0.1586]])
>>> t = torch.as_strided(x, (2, 2), (1, 2), 1)
>>> t
tensor([[0.6291, 0.1586],
        [1.0795, 2.1939]])
A 3 by 3 tensor is reinterpreted as a 2 by 2 tensor and everything is shifted around. Clearly this is a powerful and dangerous operation. It's so dangerous that both PyTorch and NumPy have explicit warnings about it in their documentation.
That means we're dealing with the good stuff.
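To see why the warnings are warranted, here's a minimal sketch of my own (not from either project's docs): the strided view aliases the original storage, so writing through it silently mutates x.
>>> x = torch.zeros(3, 3)
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t += 1
>>> x
tensor([[1., 1., 1.],
        [1., 0., 0.],
        [0., 0., 0.]])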
Although matrix traces rarely find themselves in neural networks, they're good for reasoning about as_strided.
A trace is a sum over the diagonal of a square matrix:
$$A = \begin{pmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \\ a_{31} & a_{32} & a_{33} \end{pmatrix}$$
$$\operatorname{tr}(A) = \sum_{i=1}^{3} a_{ii} = a_{11} + a_{22} + a_{33}$$
m = torch.randn(16, 16)
tr_m = torch.trace(m)
Implemented as a for loop:
tr_m = 0
for x in range(16):
    idx = x * 16 + x
    tr_m += m.flatten()[idx]
The key operation is the indexing math, which we can split out as a single equation for a single dimension (with $|x|$ denoting the size of x's dimension, 16 here): $\text{index} = x \cdot |x| + x$
Generally, once the index equation is isolated, we can ditch the for loop and use as_strided.
The trick is to differentiate with respect to the loop variable (x in this case), and set that as the stride.
$$\frac{\partial\, \text{index}}{\partial x} = |x| + 1$$
n = m.clone()
tr_n = torch.as_strided(n, (16,), (16 + 1,)).sum()
torch.testing.assert_allclose(tr_m, tr_n)
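The same idea generalizes past 16: for an n by n matrix the diagonal elements sit n + 1 apart in flat storage. A small sketch (strided_trace is a hypothetical helper name, not a torch function):
def strided_trace(m):
    # n diagonal elements, spaced n + 1 apart in the flattened matrix
    n = m.shape[0]
    return torch.as_strided(m, (n,), (n + 1,)).sum()

torch.testing.assert_allclose(strided_trace(torch.eye(5)), torch.tensor(5.0))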
Cool, right? Maybe not yet. It gets better, I promise.
Now, we can jump into two dimensional index equations. The outer product is another rarely employed operation in neural networks, but its implementation demonstrates how strides can be used for broadcasting to higher dimensions.
$$u = \begin{bmatrix} u_1 \\ u_2 \\ \vdots \\ u_m \end{bmatrix}, \quad v = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_n \end{bmatrix}$$
$$u \otimes v = \begin{bmatrix} u_1 v_1 & u_1 v_2 & \dots & u_1 v_n \\ u_2 v_1 & u_2 v_2 & \dots & u_2 v_n \\ \vdots & \vdots & \ddots & \vdots \\ u_m v_1 & u_m v_2 & \dots & u_m v_n \end{bmatrix}$$
As a for loop:
for i in range(16):
    for j in range(16):
        outer[i,j] = u[i] * v[j]
It's a bit easier to see the index equations for u and v once you add in the variables of the loops:
for i in range(16):
    for j in range(16):
        outer[i,j] = u[i + 0 * j] * v[0 * i + j]
So our index equations are
$$\text{idx}_u = i + 0 \cdot j, \quad \text{idx}_v = 0 \cdot i + j$$
Differentiate to find the strides of $\text{idx}_u$:
$$\frac{\partial\, \text{idx}_u}{\partial i} = 1, \quad \frac{\partial\, \text{idx}_u}{\partial j} = 0$$
and do the same for $\text{idx}_v$ to get this as_strided implementation:
u = torch.randn(16)
v = torch.randn(16)
outer_0 = torch.outer(u, v)
outer_1 = torch.as_strided(u, (16, 16), (1, 0)) * torch.as_strided(v, (16, 16), (0, 1))
torch.testing.assert_allclose(outer_0, outer_1)
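To back up the earlier claim that NumPy has the same primitives, here's the outer product again with np.lib.stride_tricks.as_strided (a sketch assuming u and v from above; note that NumPy strides are measured in bytes, so a repeated axis gets a stride of 0 bytes and a real axis gets the element size):
import numpy as np

un, vn = u.numpy(), v.numpy()
item = un.strides[0]  # element size in bytes (4 for float32)
outer_np = (np.lib.stride_tricks.as_strided(un, (16, 16), (item, 0)) *
            np.lib.stride_tricks.as_strided(vn, (16, 16), (0, item)))
np.testing.assert_allclose(outer_np, outer_0.numpy())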
Finally we can talk about the real stuff: matrix multiplications and convolutions.
In machine learning, matrix multiplications are found in many places (transformers, MLPs) and go by many names (dense, linear, dot, tensor contraction). A matrix multiplication is a straightforward extension of the previous two examples:
for m in range(16):
    for k in range(16):
        for n in range(16):
            C[m * 16 + n] += A[m * 16 + k + n * 0] * B[m * 0 + k * 16 + n]
$$\text{idx}_a = 16 \cdot m + k + 0 \cdot n$$
$$\frac{\partial\, \text{idx}_a}{\partial m} = 16, \quad \frac{\partial\, \text{idx}_a}{\partial k} = 1, \quad \frac{\partial\, \text{idx}_a}{\partial n} = 0$$
(not shown, but the same idea holds for $\text{idx}_b$)
a = torch.randn(16, 16)
b = torch.randn(16, 16)
mm_0 = torch.mm(a, b)
mm_1 = (torch.as_strided(a, (16, 16, 16), (16, 1, 0)) * torch.as_strided(b, (16, 16, 16), (0, 16, 1))).sum(1)
torch.testing.assert_allclose(mm_0, mm_1)
We can see that it's effectively a 3d outer product followed by a summation.
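As a quick sanity check of that claim (reusing a and b from above), the broadcasted (m, k, n) tensor before the reduction holds a[m, k] * b[k, n] at every index:
expanded = (torch.as_strided(a, (16, 16, 16), (16, 1, 0)) *
            torch.as_strided(b, (16, 16, 16), (0, 16, 1)))
# e.g. the (3, 5, 7) entry is the product of a[3, 5] and b[5, 7]
assert torch.allclose(expanded[3, 5, 7], a[3, 5] * b[5, 7])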
Here's a cool one. Pretty much any convolution you'd want can be implemented with as_strided + sum. Even crazy ones like strided, asymmetrically padded, dilated, grouped 3d convolutions (left as an exercise to the reader).
To review, a convolution is a sweeping operation that aggregates windows of values; picture a small window sweeping across a 2d input, without strides or dilation.
Before getting into how 2d convolutions work, let's start with one dimension. It can be written in a loop like so:
for i in range(14):
    for k in range(3):
        Y[i] += X[i + k]
The index equation is $\text{index} = i + k$, so
$$\frac{\partial\, \text{index}}{\partial i} = 1, \quad \frac{\partial\, \text{index}}{\partial k} = 1$$
On its face, this doesn't seem to have enough information to implement the operation. There's nothing about the window of size 3 sweeping over the input.
That information is contained in the output shape of the call to as_strided: (14, 3). Much like the matrix multiply, we're broadcasting into a higher dimensional tensor and then reducing over the newly created axis.
X = torch.randn(16)
conv_0 = torch.as_strided(X, (14, 3), (1, 1)).sum(axis=1)
conv_1 = torch.nn.functional.conv1d(X.reshape(1,1,16), torch.ones(1,1,3)).flatten()
torch.testing.assert_allclose(conv_0, conv_1)
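The all-ones kernel above is just a plain sum; with a real weight vector you multiply the (14, 3) windows by the kernel before reducing. A sketch reusing X from above (PyTorch's conv1d is cross-correlation, so no kernel flip is needed):
w = torch.randn(3)
wconv_0 = (torch.as_strided(X, (14, 3), (1, 1)) * w).sum(axis=1)
wconv_1 = torch.nn.functional.conv1d(X.reshape(1, 1, 16), w.reshape(1, 1, 3)).flatten()
torch.testing.assert_allclose(wconv_0, wconv_1)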
Extending to a two dimensional, strided-by-2 convolution, we get the index equations
$$\text{idx}_x = 2 \cdot x + k_x, \quad \text{idx}_y = 2 \cdot y + k_y$$
For $\text{idx}_x$ we get
$$\frac{\partial\, \text{idx}_x}{\partial x} = 2, \quad \frac{\partial\, \text{idx}_x}{\partial k_x} = 1$$
We can keep this 2d by multiplying the derived results by the original strides of the corresponding input axes. Since x is the leading dimension, it already has a stride of |y| (16 here).
X = torch.randn(16, 16)
sconv2d_0 = torch.as_strided(X, (7, 7, 3, 3), (2 * 16, 2, 16, 1)).sum(axis=(2,3))
sconv2d_1 = torch.nn.functional.conv2d(X.reshape(1,1,16,16), torch.ones(1,1,3,3), stride=2).reshape(7,7)
torch.testing.assert_allclose(sconv2d_0, sconv2d_1)
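For a taste of the crazier variants mentioned earlier, dilation just scales the window strides. A sketch of a dilation-2, 3x3 convolution reusing X (the output shrinks to 16 - 2 * (3 - 1) = 12 per side):
dconv_0 = torch.as_strided(X, (12, 12, 3, 3), (16, 1, 2 * 16, 2)).sum(axis=(2, 3))
dconv_1 = torch.nn.functional.conv2d(X.reshape(1, 1, 16, 16), torch.ones(1, 1, 3, 3), dilation=2).reshape(12, 12)
torch.testing.assert_allclose(dconv_0, dconv_1)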
Interestingly, Tinygrad uses this exact approach for its convolution implementation:
tx = np.lib.stride_tricks.as_strided(gx,
  shape=(bs, ctx.groups, cin, oy, ox, H, W),
  strides=(*gx.strides[0:3], gx.strides[3]*ys, gx.strides[4]*xs, *gx.strides[3:5]),
  writeable=False,
)
Although many operations can be expressed this way, there aren't any frameworks that directly optimize for the as_strided + sum pair of operations. As a result, the performance can be quite poor.
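A rough, machine-dependent way to see why (a sketch of my own, numbers will vary): the strided matmul materializes a full (n, n, n) intermediate before reducing, while torch.mm never does.
import timeit

n = 256
a_big, b_big = torch.randn(n, n), torch.randn(n, n)
mm = lambda: torch.mm(a_big, b_big)
strided = lambda: (torch.as_strided(a_big, (n, n, n), (n, 1, 0)) *
                   torch.as_strided(b_big, (n, n, n), (0, n, 1))).sum(1)
print(timeit.timeit(mm, number=10), timeit.timeit(strided, number=10))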
If you're interested in the approach, I've been working on the experimental framework loop_tool, which really runs with this idea. The API is symbolic and the differentiation is done for you, but performance is still a work in progress. (pip install loop_tool_py)
Here's a 1d convolution in loop_tool:
x_ik = x.to(i, k, constraints=[(idx, i + k)])
y = (x_ik * w).sum(k)
(An interactive tutorial can be found on colab)
If you're not interested in using a framework, the underlying loops with strided access are not too painful. For example, a stride-by-3, dilated convolution in C++:
for (auto i = 0; i < I; ++i) {
  for (auto k = 0; k < K; ++k) {
    C[i] += X[i * 3 + k * 2] * W[i * 0 + k * 1];
  }
}
When optimizing by hand, just be sure to keep in mind which axes are reductions. To reorder over reductions (induce parallelism), you can check out Halide's rfactor tutorial.