Safari supports WebGPU experimentally, with kernels written in WSL. I wrote a simple tuner that tries to optimize matrix multiplication. If you have Safari, you can try it here. (You'll need to enable WebGPU in Develop > Experimental Features.)
My M1 MacBook Air achieves 900 GFLOPS after a couple of seconds of tuning. My Intel MacBook Pro (16-inch, 2019, i9) only hits 100 GFLOPS with the same exhaustive search.
For reference, MobileNet v3 Large (x1.0) is ~219 MFLOPs. At roughly 1 TFLOPS, it could theoretically run about 4,500 inferences per second. The base BERT model (12 layers) is 11.2 GFLOPs, so at the same rate one could theoretically run it about 90 times a second.
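The estimate above is just sustained FLOP/s divided by FLOPs per inference. A minimal sketch of that arithmetic, taking the quoted FLOP counts at face value; it is an upper bound that ignores memory bandwidth, kernel-launch overhead and any non-matmul work.

```python
# Back-of-the-envelope throughput: inferences/s = sustained FLOP/s / FLOPs per inference.
# Upper bound only: ignores memory bandwidth, launch overhead and non-matmul ops.
MODELS = {
    "MobileNet v3 Large (x1.0)": 219e6,  # ~219 MFLOPs, as quoted above
    "BERT base (12 layers)": 11.2e9,     # 11.2 GFLOPs, as quoted above
}


def inferences_per_second(flops_per_second: float, flops_per_inference: float) -> float:
    return flops_per_second / flops_per_inference


for label, rate in [("900 GFLOPS", 900e9), ("1 TFLOPS", 1e12)]:
    for model, cost in MODELS.items():
        print(f"{label:>10} | {model:<26} | {inferences_per_second(rate, cost):6.0f} / s")
```

At the measured 900 GFLOPS this gives roughly 4,110 MobileNet and 80 BERT inferences per second; at 1 TFLOPS, roughly 4,566 and 89.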
The tuning code can be found here.
The basic idea is to tile memory accesses, vectorize, use mad (multiply-add) instructions, and tune the threading and dispatch parameters.
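The threading/dispatch part of the tuning amounts to a search over workgroup shapes. Here is a minimal sketch, assuming power-of-two workgroup sizes, a hypothetical cap of 256 threads per workgroup, and a `benchmark` callable that stands in for compiling the kernel and timing it in the browser (none of these names come from the actual tuning code). Each thread writes a 4 x 4 block, so a 1024 x 1024 output needs a 256 x 256 grid of threads, and every `numthreads` shape that divides that grid evenly fixes the dispatch counts.

```python
from itertools import product
from typing import Callable, Iterator, Tuple

Dim3 = Tuple[int, int, int]
Config = Tuple[Dim3, Dim3]  # (numthreads, dispatch)


def candidate_configs(grid_x: int, grid_y: int, max_threads: int = 256) -> Iterator[Config]:
    """Yield every power-of-two (numthreads, dispatch) pair that exactly tiles the thread grid."""
    sizes = [1 << i for i in range(11)]  # 1, 2, 4, ..., 1024
    for tx, ty in product(sizes, sizes):
        if tx * ty > max_threads or grid_x % tx or grid_y % ty:
            continue  # too many threads per workgroup, or the grid would not be covered exactly
        yield (tx, ty, 1), (grid_x // tx, grid_y // ty, 1)


def exhaustive_search(grid_x: int, grid_y: int,
                      benchmark: Callable[[Dim3, Dim3], float],
                      max_threads: int = 256) -> Config:
    """Time every candidate with the (hypothetical) benchmark and return the fastest pair."""
    return min(candidate_configs(grid_x, grid_y, max_threads),
               key=lambda cfg: benchmark(*cfg))


# 1024 x 1024 output, 4 x 4 elements per thread -> 256 x 256 threads.
configs = dict(candidate_configs(256, 256))
print(configs[(2, 8, 1)])  # (128, 32, 1)
```

The real per-workgroup thread limit depends on the device, so `max_threads` should come from the adapter rather than being hard-coded as it is here.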
The result is a kernel that looks like this:
// Each thread computes a 4 x 4 block of C: rows 4m..4m+3, columns 4n..4n+3 (one float4 per row).
// A, B and C are 1024 x 1024 row-major float matrices viewed as float4s, i.e. 256 float4s per row.
[numthreads(2, 8, 1)]
compute void main(constant float4[] A : register(u0),
                  constant float4[] B : register(u1),
                  device float4[] C : register(u2),
                  float3 threadID : SV_DispatchThreadID) {
    uint m = uint(threadID.x);
    uint n = uint(threadID.y);
    float4 result_0_0 = float4(0.0, 0.0, 0.0, 0.0);
    float4 result_1_0 = float4(0.0, 0.0, 0.0, 0.0);
    float4 result_2_0 = float4(0.0, 0.0, 0.0, 0.0);
    float4 result_3_0 = float4(0.0, 0.0, 0.0, 0.0);
    for (uint k = 0; k < 256; k++) {
        // 4 x 4 block of A: rows 4m..4m+3, columns 4k..4k+3.
        float4 a_0_0 = A[(m * 4 + 0) * 256 + (k * 1 + 0)];
        float4 a_1_0 = A[(m * 4 + 1) * 256 + (k * 1 + 0)];
        float4 a_2_0 = A[(m * 4 + 2) * 256 + (k * 1 + 0)];
        float4 a_3_0 = A[(m * 4 + 3) * 256 + (k * 1 + 0)];
        // 4 x 4 block of B: rows 4k..4k+3, columns 4n..4n+3.
        float4 b_0_0 = B[(k * 4 + 0) * 256 + (n * 1 + 0)];
        float4 b_0_1 = B[(k * 4 + 1) * 256 + (n * 1 + 0)];
        float4 b_0_2 = B[(k * 4 + 2) * 256 + (n * 1 + 0)];
        float4 b_0_3 = B[(k * 4 + 3) * 256 + (n * 1 + 0)];
        // Scalar x float4 products accumulated into registers (candidates for mad).
        result_0_0 += mul(a_0_0.x, b_0_0);
        result_1_0 += mul(a_1_0.x, b_0_0);
        result_2_0 += mul(a_2_0.x, b_0_0);
        result_3_0 += mul(a_3_0.x, b_0_0);
        result_0_0 += mul(a_0_0.y, b_0_1);
        result_1_0 += mul(a_1_0.y, b_0_1);
        result_2_0 += mul(a_2_0.y, b_0_1);
        result_3_0 += mul(a_3_0.y, b_0_1);
        result_0_0 += mul(a_0_0.z, b_0_2);
        result_1_0 += mul(a_1_0.z, b_0_2);
        result_2_0 += mul(a_2_0.z, b_0_2);
        result_3_0 += mul(a_3_0.z, b_0_2);
        result_0_0 += mul(a_0_0.w, b_0_3);
        result_1_0 += mul(a_1_0.w, b_0_3);
        result_2_0 += mul(a_2_0.w, b_0_3);
        result_3_0 += mul(a_3_0.w, b_0_3);
    }
    C[(m * 4 + 0) * 256 + (n * 1 + 0)] = result_0_0;
    C[(m * 4 + 1) * 256 + (n * 1 + 0)] = result_1_0;
    C[(m * 4 + 2) * 256 + (n * 1 + 0)] = result_2_0;
    C[(m * 4 + 3) * 256 + (n * 1 + 0)] = result_3_0;
}
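Dispatch parameters: 128, 32, 1 (workgroup counts in x, y, z). With numthreads(2, 8, 1) that launches a 256 x 256 grid of threads, which covers the 1024 x 1024 output in 4 x 4 blocks.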
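To sanity-check the kernel's indexing, it can be emulated on the CPU. This is a hypothetical pure-Python port (the helper names `pack`, `kernel_thread` and `tiled_matmul` are illustrative, not from the tuning code) that shrinks the problem to 8 x 8 so it runs instantly. Here `s` plays the role of the hard-coded 256, the number of float4s per row, and the result is checked against a naive triple loop.

```python
# CPU emulation of the WSL kernel's indexing (pure Python, no GPU).
# Matrices are N x N, row-major, viewed as float4s: s = N // 4 float4s per row.


def pack(mat):
    """N x N list of lists -> flat row-major list of float4s (4-element lists)."""
    return [row[c:c + 4] for row in mat for c in range(0, len(row), 4)]


def unpack(flat, n):
    """Inverse of pack: flat list of float4s -> N x N list of lists."""
    s = n // 4
    return [[x for v in flat[r * s:(r + 1) * s] for x in v] for r in range(n)]


def kernel_thread(A, B, C, m, n, s):
    """One GPU thread: rows 4m..4m+3 of C, float4 column n (scalar columns 4n..4n+3)."""
    result = [[0.0] * 4 for _ in range(4)]               # result_0_0 .. result_3_0
    for k in range(s):
        a = [A[(m * 4 + i) * s + k] for i in range(4)]   # a_0_0 .. a_3_0
        b = [B[(k * 4 + j) * s + n] for j in range(4)]   # b_0_0 .. b_0_3
        for j in range(4):                                # components .x, .y, .z, .w
            for i in range(4):
                # result_i_0 += mul(a_i_0.<component j>, b_0_j)
                result[i] = [r + a[i][j] * bj for r, bj in zip(result[i], b[j])]
    for i in range(4):
        C[(m * 4 + i) * s + n] = result[i]


def tiled_matmul(a_mat, b_mat):
    n = len(a_mat)
    s = n // 4
    A, B = pack(a_mat), pack(b_mat)
    C = [None] * (n * s)
    for m in range(s):        # threadID.x
        for col in range(s):  # threadID.y
            kernel_thread(A, B, C, m, col, s)
    return unpack(C, n)


def naive_matmul(a_mat, b_mat):
    n = len(a_mat)
    return [[sum(a_mat[i][k] * b_mat[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


N = 8  # small stand-in for 1024; must be a multiple of 4
A_mat = [[(i * 3 + j) % 7 - 3 for j in range(N)] for i in range(N)]
B_mat = [[(i + 2 * j) % 5 - 2 for j in range(N)] for i in range(N)]
assert tiled_matmul(A_mat, B_mat) == naive_matmul(A_mat, B_mat)
```

Small integer inputs keep every partial sum exact, so the comparison can use plain equality instead of a tolerance.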
Clearly more can be done to tune it (such as tiling the K dimension more aggressively or adding more levels of tiling), but I'm quite happy with the results. Hitting nearly 1 TFLOPS in the browser (50% of peak) is extremely empowering, and it's exciting to see such technology available.
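The kernel above reads like generated code (the `result_i_j`, `a_i_j` and `b_i_j` names, the `k * 1 + 0` offsets), which is what makes ideas like further tiling cheap to try. Below is a hypothetical generator in that style, not the actual tuning code: `tm` (rows per thread) and `tn` (float4 columns per thread) are assumed tiling knobs, and `tx`, `ty` are the `numthreads` values. With `tm=4, tn=1, tx=2, ty=8` it emits the kernel above, apart from formatting and comments.

```python
COMPONENTS = "xyzw"


def gen_kernel(tm: int, tn: int, tx: int, ty: int, s: int = 256) -> str:
    """Emit a WSL kernel where each thread owns tm rows x tn float4 columns of C.

    Hypothetical knobs; s is the number of float4s per matrix row (256 for 1024 x 1024).
    """
    tiles = [(i, j) for i in range(tm) for j in range(tn)]
    lines = [
        f"[numthreads({tx}, {ty}, 1)]",
        "compute void main(constant float4[] A : register(u0),",
        "                  constant float4[] B : register(u1),",
        "                  device float4[] C : register(u2),",
        "                  float3 threadID : SV_DispatchThreadID) {",
        "    uint m = uint(threadID.x);",
        "    uint n = uint(threadID.y);",
    ]
    lines += [f"    float4 result_{i}_{j} = float4(0.0, 0.0, 0.0, 0.0);" for i, j in tiles]
    lines.append(f"    for (uint k = 0; k < {s}; k++) {{")
    # Loads: one float4 of A per tile row, four rows of B per float4 column.
    lines += [f"        float4 a_{i}_0 = A[(m * {tm} + {i}) * {s} + (k * 1 + 0)];"
              for i in range(tm)]
    lines += [f"        float4 b_{j}_{c} = B[(k * 4 + {c}) * {s} + (n * {tn} + {j})];"
              for j in range(tn) for c in range(4)]
    # Multiply-adds: one scalar x float4 product per (component, output tile).
    for c, comp in enumerate(COMPONENTS):
        lines += [f"        result_{i}_{j} += mul(a_{i}_0.{comp}, b_{j}_{c});" for i, j in tiles]
    lines.append("    }")
    lines += [f"    C[(m * {tm} + {i}) * {s} + (n * {tn} + {j})] = result_{i}_{j};"
              for i, j in tiles]
    lines.append("}")
    return "\n".join(lines)


print(gen_kernel(4, 1, 2, 8))
```

For a 1024 x 1024 problem the thread grid becomes (1024 / tm) x (256 / tn), so the dispatch parameters have to be recomputed for each tile shape. Larger tiles also need more registers per thread, so bigger tiles are not automatically faster; that is what the tuner's search is for.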