JavaScript has some great ideas in the server space. Of those, async programming and reactive server definitions are potentially very useful in machine learning.
Below are some ideas found in Shumai, an experimental framework to expose machine learning research and training to the server-side JavaScript world.
To motivate this post, let's pretend we've got two teams working on an ML pipeline. One is building the model and the other is getting a dataset together as well as defining a metric for correctness.
Model Team                  Dataset Team
    |                            |
    v                            v
[ Server ]  < ===== >       [ Client ]
 - weights                   - dataset
 - architecture              - loss function
 - GPUs                      - CPUs
Each team owns the hardware they're building for. The model team has big GPUs that can churn through very large machine learning models. The dataset team has very fast CPUs that can generate lots of data. They don't need GPUs. If bandwidth becomes an issue, the hardware can be co-located.
It's also important to note that the teams iterate at different rates. Keeping the problems tightly coupled would slow everything down.
To start, let's assume we have the solution and just need to serve it up.
// server
import * as sm from '@shumai/shumai'
const weight = sm.scalar(1.718)
weight.requires_grad = true
// keeping it trivial for the example
function model(X) {
  return X.mul(weight)
}
To expose this model to the network (err, the other team), we can use sm.io.serve(map, options) and map the /forward URL to an invocation of model.
// server
sm.io.serve(
  {
    forward: (user, X) => {
      const Y = model(X)
      return Y
    },
  },
  { port: 3000 }
)
And now we can run the model remotely with tfetch, which is effectively fetch but for tensors.
// client
import * as sm from '@shumai/shumai'
const url = 'localhost:3000'
const [input, ref_output] = get_data()
const output = await sm.io.tfetch(`${url}/forward`, input)
That's it. Everything is wrapped up and serialized for you! 🚀
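Shumai's actual wire format is internal, but the idea behind "wrapped up and serialized" is simple: a tensor is just a flat buffer of numbers plus a shape, and a single packed buffer makes a valid fetch() body. A hedged, dependency-free sketch (the packing layout here is an assumption for illustration, not Shumai's real format):

```javascript
// Hypothetical tensor wire format: [rank, ...shape, ...values] in one
// Float32Array, whose underlying ArrayBuffer can be POSTed directly.
function serializeTensor(data, shape) {
  const packed = new Float32Array(1 + shape.length + data.length)
  packed[0] = shape.length
  packed.set(shape, 1)
  packed.set(data, 1 + shape.length)
  return packed.buffer
}

function deserializeTensor(buffer) {
  const packed = new Float32Array(buffer)
  const rank = packed[0]
  const shape = Array.from(packed.slice(1, 1 + rank))
  const data = Array.from(packed.slice(1 + rank))
  return { shape, data }
}

// round-trip a 2x2 tensor
const { shape, data } = deserializeTensor(
  serializeTensor([1.5, 2.5, 3.5, 4.5], [2, 2])
)
```

The point is that the client and server never exchange framework objects, only bytes, which is what lets the two teams iterate independently.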
But what if our model is bad? We should tell the server about that.
First we'll need to make the server able to update its model. Theoretically we could do something like this...
sm.io.serve(
  // ...
  optimize: (user, jacobian) => {
    const differentiated_tensors = Y.backward(jacobian)
    // update each tensor with a learning rate of 1e-3
    sm.optim.sgd(differentiated_tensors, 1e-3)
  }
  // ...
)
But we don't have Y available to us. 😢 In fact, every invocation of forward is going to change Y and potentially change the backward function (if the model weren't this simple). What can we do?
JavaScript captures to the rescue! We can save an invocation to our backward pass on the user object.
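Stripped of tensors, the trick is ordinary closure capture: each call to forward creates a fresh closure over that call's own local state, and stashing the closure on the per-user object lets a later optimize call reach back into it. A dependency-free sketch of the same pattern (plain numbers, hypothetical handler names):

```javascript
// A scalar "model" y = weight * x, with handlers shaped like the server map.
let weight = 2

const handlers = {
  forward: (user, x) => {
    const y = weight * x
    // capture x, the state of *this* invocation, for the backward step;
    // dy/dweight = x, so the update is lr * grad * x
    user.opt = (grad) => {
      weight -= 0.1 * grad * x
    }
    return y
  },
  optimize: (user, grad) => user.opt(grad),
}

const user = {}
const y = handlers.forward(user, 3) // closure remembers x = 3
handlers.optimize(user, 1)          // weight -= 0.1 * 1 * 3
```

Each user object carries its own captured closure, so concurrent clients don't stomp on each other's backward state.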
sm.io.serve(
  {
    forward: (user, tensor) => {
      const Y = model(tensor)
      // capture the optimization function and save it on the user object
      user.opt = (jacobian) => {
        sm.optim.sgd(Y.backward(jacobian), 1e-3)
      }
      return Y
    },
    optimize: (user, jacobian) => {
      user.opt(jacobian)
    }
  }
  // ...
)
(Technical Note: we're merging backward and optimize steps. This is of course not required, it just makes the example simpler. We could always split up the server to calculate gradients first and then apply the update on a different call.)
Now we need to calculate the Jacobian from the client side.
// client
output.requires_grad = true
const loss = sm.loss.mse(output, ref_output)
loss.backward()
await sm.io.tfetch(`${url}/optimize`, output.grad)
Done.
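For this scalar setup the Jacobian the client ships is easy to check by hand: with squared error L = (y − ŷ)², the gradient flowing back to the server is ∂L/∂y = 2(y − ŷ). A plain-number sketch of the client-side backward step (the specific values here are made up for illustration):

```javascript
// Scalar check of the client's backward pass.
const output = 3.436     // e.g. model(2) with weight = 1.718
const ref_output = 4.0   // what the dataset team says it should be

// loss = (output - ref)^2, so d(loss)/d(output) = 2 * (output - ref)
const loss = (output - ref_output) ** 2
const jacobian = 2 * (output - ref_output)
```

That jacobian value is exactly the tensor the client POSTs to /optimize; the server never needs to know which loss function produced it.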
If you're not a huge fan of the stateful requires_grad semantics and would prefer a more functional solution, we can actually pass a backward pass callback to the original tfetch:
// client
// this will be called automatically
const backward = async (grad) => {
  await sm.io.tfetch(`${url}/optimize`, grad.grad_in)
}
const output = await sm.io.tfetch(`${url}/forward`, input, { grad_fn: backward })
const loss = sm.loss.mse(output, ref_output)
await loss.backward()
This of course will need to be wrapped in a loop.
I like using sm.util.viter to get visual iteration.
for (const _ of sm.util.viter(200)) {
  // ...
}
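Putting the pieces together: with the network calls inlined, the whole client/server exchange reduces to a plain gradient-descent loop. A dependency-free simulation (plain numbers standing in for tensors and tfetch calls; the dataset here is invented so the true weight is 3):

```javascript
// End-to-end simulation of the training loop above.
let weight = 1.718                      // server state
const lr = 1e-2                         // matches the sgd learning rate
const dataset = [[1, 3], [2, 6], [3, 9]] // (x, ref) pairs, ref = 3 * x

for (let step = 0; step < 200; step++) {
  for (const [x, ref] of dataset) {
    const y = weight * x                // the "/forward" call
    const jacobian = 2 * (y - ref)      // client-side MSE backward
    weight -= lr * jacobian * x         // the "/optimize" call: dy/dweight = x
  }
}
```

After 200 iterations the server's weight has converged to the dataset team's target, even though neither side ever saw the other's code.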
The code in the post can be found here: client and server.
If you're interested, a more involved model parallel example lives here. The server semantics popular in JavaScript scale quite well to robust multi-node training regimes. Some interesting artifacts of this way of setting up training:
await Promise.all(submodels)