I saw an interesting post over at lobste.rs about efficiently computing the Euclidean distance between vectors using SIMD intrinsics. I thought it was a cool problem and decided to take a stab at making it fast on a GPU.
It boils down to finding the vector in a set of vectors DB that is closest to a given query vector q:
\min_{d \in DB} \sqrt{\sum_i (d_i-q_i)^2}
The author notes that squaring is monotonic over non-negative numbers, so removing the square root doesn't change which element attains the minimum. That lets us simplify the calculation a bit. We instead find
\min_{d \in DB} \sum_i (d_i-q_i)^2
which is the element in the database that has the smallest Euclidean distance squared from the query vector.
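To see that dropping the square root is safe, here is a small pure-Python check (standard library only; the random toy vectors and helper names are made up for illustration). It confirms that the index minimizing the Euclidean distance also minimizes the squared distance:

```python
import math
import random

def squared_distance(d, q):
    # Sum of squared element-wise differences: Euclidean distance without the sqrt
    return sum((di - qi) ** 2 for di, qi in zip(d, q))

def nearest_index(db, q, score):
    # Index of the database vector with the smallest score
    return min(range(len(db)), key=lambda i: score(db[i], q))

random.seed(0)
db = [[random.gauss(0, 1) for _ in range(8)] for _ in range(100)]
q = [random.gauss(0, 1) for _ in range(8)]

by_distance = nearest_index(db, q, math.dist)        # with the square root
by_squared = nearest_index(db, q, squared_distance)  # without it
print(by_distance == by_squared)
```

Because the square root never reorders non-negative values, both calls pick the same index (barring exact floating-point ties).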
Immediately we can see that the problem as stated is memory bound. We can compute a simple subtraction and multiplication much faster than we can pull a vector from RAM. To get around this type of issue, we should batch the queries. This way we can reuse the vectors pulled from the DB for more than one query and hope to increase processor utilization.
However, it's not super clear how we can batch the queries. For every query in a batch Q we want:
\forall q \in Q : \min_{d \in DB} \sum_i (d_i-q_i)^2
It looks like we should be able to use three for-loops: two outer loops iterating over qs and ds and one inner loop that sums over the (d_i-q_i)^2 operation for each element within each vector. Something like:
results = []
for q in Q:
    best_dist, best_idx = float("inf"), None
    for idx, d in enumerate(DB):
        s = sum((di - qi)**2 for di, qi in zip(d, q))
        if s < best_dist:
            best_dist, best_idx = s, idx
    results.append(best_idx)  # index of the closest vector in DB for this query
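In that loop order every database vector is re-read once per query. One way to picture the batching idea is to swap the loops, so each d is fetched once and compared against every query in the batch while it is still in cache. The sketch below is a minimal pure-Python illustration under that assumption; `batched_nearest` is a hypothetical name, and plain lists stand in for real memory traffic:

```python
def batched_nearest(db, queries):
    """Return, for each query, the index of the closest database vector.

    The outer loop walks the database once; each fetched vector d is reused
    for every query in the batch instead of being re-read per query.
    """
    best_dist = [float("inf")] * len(queries)
    best_idx = [None] * len(queries)
    for idx, d in enumerate(db):           # each d is loaded once...
        for j, q in enumerate(queries):    # ...and reused for all queries
            s = sum((di - qi) ** 2 for di, qi in zip(d, q))
            if s < best_dist[j]:
                best_dist[j], best_idx[j] = s, idx
    return best_idx

db = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
queries = [[0.9, 1.2], [4.0, 6.0], [-1.0, 0.0]]
print(batched_nearest(db, queries))  # [1, 2, 0]
```

The arithmetic is identical to the per-query loop; only the order of memory accesses changes.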
But can we do anything simpler than (d_i-q_i)^2? Removing the square root was easy enough; can we remove anything else? First we expand the terms:
\sum_i d_i^2 - \sum_i 2 d_i q_i + \sum_i q_i^2
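As a quick sanity check of the expansion, take the made-up 2-dimensional vectors d = (1, 2) and q = (3, 1). Both forms give the same value:

```latex
\sum_i (d_i - q_i)^2 = (1-3)^2 + (2-1)^2 = 4 + 1 = 5
\sum_i d_i^2 - \sum_i 2 d_i q_i + \sum_i q_i^2 = (1 + 4) - 2(3 + 2) + (9 + 1) = 5 - 10 + 10 = 5
```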
Since the query vector is fixed while we search over DB, \sum_i q_i^2 is the same for every d and has no impact on which d attains the minimum. So we can remove \sum_i q_i^2. Now we have:
\sum_i d_i^2 - \sum_i 2 d_i q_i
The \sum_i d_i^2 term doesn't depend on the query, so there's no need to recompute it for each query. Let's precompute it once per database vector as d_{bias} = \sum_i d_i^2.
d_{bias} - \sum_i 2 d_i q_i
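The reduced score can be checked the same way. This pure-Python sketch (the function names are hypothetical) computes d_{bias} = \sum_i d_i^2 once per database vector, ranks candidates by d_{bias} - \sum_i 2 d_i q_i, and compares the result against the full squared distance on random data:

```python
import random

def precompute_bias(db):
    # d_bias = sum_i d_i^2, computed once per database vector, independent of any query
    return [sum(di * di for di in d) for d in db]

def nearest_by_bias(db, db_bias, q):
    # Score = d_bias - 2 * (d . q); the query-only term sum_i q_i^2 is dropped
    scores = [b - 2 * sum(di * qi for di, qi in zip(d, q)) for d, b in zip(db, db_bias)]
    return min(range(len(db)), key=scores.__getitem__)

def nearest_brute_force(db, q):
    # Reference: full squared Euclidean distance
    return min(range(len(db)), key=lambda i: sum((di - qi) ** 2 for di, qi in zip(db[i], q)))

random.seed(1)
db = [[random.uniform(-1, 1) for _ in range(16)] for _ in range(200)]
db_bias = precompute_bias(db)
queries = [[random.uniform(-1, 1) for _ in range(16)] for _ in range(20)]
print(all(nearest_by_bias(db, db_bias, q) == nearest_brute_force(db, q) for q in queries))
```

Each score differs from the true squared distance only by the constant \sum_i q_i^2, so the ordering is preserved.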
Better yet, we can divide the whole expression by 2; dividing by a positive constant preserves the ordering, so the minimum is still the element closest in Euclidean distance. So we'll update our bias term to be
d_{bias} = \frac{1}{2}\sum_i d_i^2
and plug everything back into the original equation:
\forall q \in Q : \min_{d, d_{bias} \in DB} d_{bias} -\sum_i d_i q_i
For simplicity we can also negate the problem and find
\forall q \in Q : \max_{d, d_{bias} \in DB} \sum_i d_i q_i + d_{bias}
with the precalculated bias as
d_{bias} = -\frac{1}{2}\sum_i d_i^2
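Stacking the queries as the rows of Q and the database vectors as the columns of DB, this can be written in matrix form as
R = Q \cdot DB + DB_{bias}
where DB_{bias} is added to every row of R, and for each query q we want \arg\max_{d} R_{q,d}.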
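To make the shapes concrete, here is a dependency-free sketch of the matrix formulation (the helper names are hypothetical). It uses the same layout as the PyTorch code that follows: queries as rows of Q, database vectors as columns of DB, and the bias broadcast across every row of R:

```python
def matmul(a, b):
    # Plain triple-loop matrix product: (m x k) @ (k x n) -> (m x n)
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]

def nearest_via_matmul(queries, db_cols):
    """queries: batch x k (one query per row); db_cols: k x n (one database vector per column)."""
    n = len(db_cols[0])
    # DB_bias = -1/2 * sum_i d_i^2 for each column d
    db_bias = [-0.5 * sum(db_cols[p][j] ** 2 for p in range(len(db_cols))) for j in range(n)]
    # R = Q . DB + DB_bias, with the bias added to every row
    r = [[v + b for v, b in zip(row, db_bias)] for row in matmul(queries, db_cols)]
    # For each query (row of R), take the argmax over database vectors
    return [max(range(n), key=row.__getitem__) for row in r]

# Toy database with vectors (0,0), (1,1), (5,5), stored one per column
db_cols = [[0.0, 1.0, 5.0],
           [0.0, 1.0, 5.0]]
queries = [[0.9, 1.2], [4.0, 6.0], [-1.0, 0.0]]
print(nearest_via_matmul(queries, db_cols))  # [1, 2, 0]
```

A GPU library replaces the naive `matmul` with a highly tuned kernel, but the bias-add and row-wise argmax are the same.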
This final equation shows that we can view the problem as a matrix multiplication with a bias, followed by a row-wise argmax. Luckily, plenty of libraries implement these operations efficiently, and since matrix multiplication is so well studied, it's easy to run on a GPU.
Here is a full example using 32-bit floats with PyTorch, a Python library for GPU-accelerated tensor computation:
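import time
import torch
device = torch.device('cuda')
batch = 1500
db_size = 1000000
print("preprocessing the database")
# one 144-dimensional database vector per column
db = torch.randn(144, db_size, device=device)
db_bias = -0.5 * torch.sum(db**2, dim=0)
print("generating a query")
# one query per row
query = torch.randn(batch, 144, device=device)
print("running the query")
torch.cuda.synchronize()  # make sure the setup work above has finished before timing
t = time.time()
result = (query @ db) + db_bias  # db_bias is broadcast across every row
idxs = torch.argmax(result, dim=1)
torch.cuda.synchronize()  # GPU kernels run asynchronously; wait for them before stopping the clock
print(f"{batch} queries in {time.time() - t:.4f} seconds")
print(idxs[0])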
On a V100 this runs the 1500 queries in about 70 milliseconds.
By massaging the problem we were able to find an equivalent and well studied problem that has efficient solutions we can leverage easily. It took 1 line of Python to preprocess our database and only two lines to run a set of 1500 queries at a rate of 21k queries per second.
Further optimizations can still be applied, such as pruning the vector space to only the first 32 dimensions (as is done in the original post) or quantizing the vectors down to low-precision formats like uint8 or uint16 for fast execution on CPU with libraries like gemmlowp.
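As a rough illustration of the quantization idea, here is a pure-Python sketch. It makes simplifying assumptions: it uses symmetric signed 8-bit quantization with one scale for the whole database and one per query (rather than the uint8/uint16 formats mentioned above), accumulates integer dot products, and keeps the bias in floating point. It does not use gemmlowp, and the function names are hypothetical:

```python
def quantize_int8(values, scale):
    # Symmetric quantization: map [-max_abs, max_abs] onto the integer range [-127, 127]
    return [max(-127, min(127, round(v / scale))) for v in values]

def nearest_quantized(queries, db):
    """Approximate argmax of q . d - 0.5 * |d|^2 using 8-bit integer dot products."""
    db_scale = max(abs(v) for d in db for v in d) / 127 or 1.0
    db_q = [quantize_int8(d, db_scale) for d in db]
    # Bias kept in floating point, computed from the original (unquantized) vectors
    db_bias = [-0.5 * sum(v * v for v in d) for d in db]
    results = []
    for q in queries:
        q_scale = max(abs(v) for v in q) / 127 or 1.0
        q_q = quantize_int8(q, q_scale)
        best_idx, best_score = None, float("-inf")
        for idx, (d_q, bias) in enumerate(zip(db_q, db_bias)):
            acc = sum(a * b for a, b in zip(q_q, d_q))  # integer accumulate (int32 on real hardware)
            score = acc * q_scale * db_scale + bias      # dequantize, then add the bias
            if score > best_score:
                best_idx, best_score = idx, score
        results.append(best_idx)
    return results

db = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
queries = [[0.9, 1.2], [4.0, 6.0], [-1.0, 0.0]]
print(nearest_quantized(queries, db))  # [1, 2, 0]
```

Quantization error can flip the answer when two candidates are nearly tied, so in practice the quantized scores are often used to shortlist candidates that are then re-ranked at full precision.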
Thanks for reading!