## Viewing Euclidean Distance Queries as Matrix Multiplications

I saw an interesting post over at [lobste.rs](https://lobste.rs/s/ub1evt/computing_euclidean_distance_on_144)
 about efficiently
 [computing the Euclidean distance between vectors](https://blog.cloudflare.com/computing-euclidean-distance-on-144-dimensions/)
 using SIMD intrinsics.
 I thought it was a cool problem and
 decided to take a stab at making it fast on a GPU.
 


It boils down to the author wanting to find the closest
 vector in a set of vectors $DB$ to a given query vector $q$.

$$
\min_{d \in DB} \sqrt{\sum_i (d_i-q_i)^2}
$$

The author notes that the square root is monotonic
 over the nonnegative reals, so dropping it doesn't
 change which vector attains the minimum.  We instead find


$$
\min_{d \in DB} \sum_i (d_i-q_i)^2
$$

which is the element in the database that has the smallest
 Euclidean distance *squared* from the query vector.
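
As a quick sanity check with made-up numbers, here's a minimal sketch in plain Python showing that the nearest vector is the same whether we minimize the distance or the squared distance, since the square root is increasing:

```python
import math

q = [1.0, 2.0, 3.0]
DB = [[1.5, 2.5, 2.0], [0.0, 0.0, 0.0], [1.0, 2.2, 3.1]]

def sq_dist(d, q):
    # squared Euclidean distance between two equal-length vectors
    return sum((di - qi) ** 2 for di, qi in zip(d, q))

# The winning index is the same with or without the square root.
nearest_by_dist = min(range(len(DB)), key=lambda j: math.sqrt(sq_dist(DB[j], q)))
nearest_by_sq = min(range(len(DB)), key=lambda j: sq_dist(DB[j], q))
assert nearest_by_dist == nearest_by_sq
```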
 
### Batching

Immediately we can see that the problem as stated is memory bound:
 we can compute a simple subtraction and multiplication much faster
 than we can pull a vector from RAM.  To get around this, we should
 batch the queries.  That way we can reuse each vector pulled from
 the $DB$ across several queries and hope to increase processor utilization.
 

However, it's not immediately clear how to batch the queries...

$$
\forall q \in Q : \min_{d \in DB} \sum_i (d_i-q_i)^2
$$

It looks like we should be able to use three for-loops:
 two outer loops iterating over $q$s and $d$s and one
 inner loop that sums over the $(d_i-q_i)^2$ operation
 for each element within each vector.  Something like:
 
```python
results = []  # nearest database index and squared distance for each query
for q in Q:
  best_idx, best_dist = None, float("inf")
  for j, d in enumerate(DB):
    s = sum((di - qi)**2 for di, qi in zip(d, q))
    if s < best_dist:
      best_idx, best_dist = j, s
  results.append((best_idx, best_dist))
```
(Here `Q` and `DB` are assumed to be in-memory sequences of equal-length vectors.)
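
The same computation can also be batched directly with an array library.  A minimal sketch using PyTorch on the CPU (the sizes here are made up; any library with pairwise distances would do):

```python
import torch

Q = torch.randn(8, 144)       # 8 query vectors, one per row
DB = torch.randn(1000, 144)   # 1000 database vectors, one per row

dists = torch.cdist(Q, DB)            # (8, 1000) pairwise Euclidean distances
nearest = torch.argmin(dists, dim=1)  # closest database index for each query
```

This lets the library reuse each database vector across the whole batch; the next section simplifies the per-element arithmetic further.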
 
### Simplification

But can we do anything simpler than $(d_i-q_i)^2$?
 Removing the square root was easy enough; can we remove anything else?
 First we expand the terms:
 
$$
\sum_i d_i^2 - \sum_i 2 d_i q_i + \sum_i q_i^2
$$


For a fixed query, $\sum_i q_i^2$ is the same for every
 database vector, so it has no effect on which vector attains
 the minimum and we can drop it.
 Now we have:
 
$$
\sum_i d_i^2 - \sum_i 2 d_i q_i
$$
 
 
The $\sum_i d_i^2$ term doesn't depend on the query at all,
 so let's precalculate it once per database vector as $d_{bias} = \sum_i d_i^2$.
 
$$
d_{bias} - \sum_i 2 d_i q_i
$$

Better yet, we can divide the whole expression by 2 and it remains
 monotonic in the Euclidean distance.  So we'll update our bias term to be
 
$$
d_{bias} = \frac{1}{2}\sum_i d_i^2
$$

and plug everything back into the original equation:

$$
\forall q \in Q : \min_{d, d_{bias} \in DB} d_{bias} -\sum_i d_i q_i
$$

For simplicity we can also negate the problem and find

$$
\forall q \in Q : \max_{d, d_{bias} \in DB} \sum_i d_i q_i + d_{bias}
$$

with the precalculated bias as

$$
d_{bias} = -\frac{1}{2}\sum_i d_i^2
$$

which, stacking the queries as the rows of a matrix $Q$, the database
 vectors as the columns of a matrix $DB$, and their biases into a vector
 $DB_{bias}$, can be rewritten as

$$
R = Q \cdot DB + DB_{bias}
$$

where, for each query (each row of $R$), we want the column index $\operatorname{argmax}_{d} R_{q,d}$.
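
As a quick end-to-end check of the derivation, here's a small CPU sketch with PyTorch and made-up sizes (in double precision so the comparison is exact for all practical purposes): the argmax of $Q \cdot DB + DB_{bias}$ picks the same database vectors as a brute-force search over squared distances.

```python
import torch

torch.manual_seed(0)
Q = torch.randn(8, 144, dtype=torch.float64)      # 8 queries, one per row
db = torch.randn(144, 1000, dtype=torch.float64)  # 1000 database vectors, one per column
db_bias = -0.5 * torch.sum(db**2, dim=0)

# Derived form: matrix multiply plus bias, then argmax per query.
via_matmul = torch.argmax(Q @ db + db_bias, dim=1)

# Reference: brute-force squared Euclidean distances, then argmin per query.
sq_dists = ((Q[:, None, :] - db.T[None, :, :]) ** 2).sum(dim=-1)
brute_force = torch.argmin(sq_dists, dim=1)

assert torch.equal(via_matmul, brute_force)
```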

### Matrix multiplication

This final equation shows us that we can view the problem as a matrix
 multiplication with bias followed by an argmax.
 Luckily there are plenty of libraries that can do this for us,
 and since the problem is well specified, it's easy to implement
 on a GPU.

Here is a full example using floating point numbers with PyTorch, a fast
Python library:

```python
import torch
import time

batch = 1500
db_size = 1000000

print("preprocessing the database")
db = torch.randn(144, db_size).to(torch.device('cuda'))
db_bias = (-0.5 * torch.sum(db**2, axis=0))

print("generating a query")
query = torch.randn(batch,144).to(torch.device('cuda'))

print("running the query")
t = time.time()
result = (query @ db) + db_bias
idxs = torch.argmax(result, axis=1)
print(idxs[0]) # to ensure we sync the result back to CPU
print(f"{batch} queries in {time.time() - t:.4f} seconds")
```

which runs 1500 queries in about 70 milliseconds on a V100.

### Conclusion

By massaging the problem we were able to find an equivalent,
well-studied problem that has efficient solutions we can easily leverage.
It took one line of Python to preprocess our database and
only two lines to run a set of 1500 queries at roughly 21k queries per second.

Further optimizations can still be applied, such as pruning the
vector space to only the first 32 dimensions (as is done in the
original post) or quantizing the vectors down to low precision
formats like uint8 or uint16 for fast execution on CPU with libraries
like gemmlowp.
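
To give a flavor of the quantization idea, here's a rough sketch of a simplified symmetric int8 scheme (not gemmlowp's asymmetric uint8 pipeline), using NumPy for the integer matmul; the sizes and scales are made up, and how much accuracy you lose depends on the data:

```python
import numpy as np

rng = np.random.default_rng(0)
db = rng.standard_normal((144, 1000)).astype(np.float32)    # one vector per column
query = rng.standard_normal((16, 144)).astype(np.float32)   # one query per row

def quantize(x):
    # Map floats to int8 with a single symmetric scale per matrix.
    scale = np.abs(x).max() / 127.0
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8), scale

db_q, db_scale = quantize(db)
q_q, q_scale = quantize(query)

# Integer matmul (accumulated in int32), then undo the scales.
scores = (q_q.astype(np.int32) @ db_q.astype(np.int32)).astype(np.float32)
scores *= q_scale * db_scale

# Add the (float) bias and take the argmax, exactly as before.
db_bias = -0.5 * np.sum(db**2, axis=0)
idxs = np.argmax(scores + db_bias, axis=1)
```

On CPU, the real speedup would come from a dedicated int8 GEMM kernel (gemmlowp or similar); the NumPy matmul here only stands in for it.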

Thanks for reading!