Fully Embracing JavaScript: Machine Learning Training as Async Services

JavaScript has some great ideas in the server space. Of those, async programming and reactive server definitions are potentially very useful in machine learning.

Below are some ideas found in Shumai, an experimental framework to expose machine learning research and training to the server-side JavaScript world.

Problem Statement

To motivate this post, let's pretend we've got two teams working on an ML pipeline. One is building the model and the other is getting a dataset together as well as defining a metric for correctness.


   Model Team            Dataset Team
       |                      |
       v                      v
      
  [  Server  ] < ===== > [  Client  ]

   - weights             - dataset
   - architecture        - loss function
   - GPUs                - CPUs

Each team owns the hardware they're building for. The model team has big GPUs that can churn through very large machine learning models. The dataset team has very fast CPUs that can generate lots of data and doesn't need GPUs. If bandwidth between the two becomes a bottleneck, the hardware can be co-located.

It's also important to note that the teams iterate at different rates. Keeping the problems tightly coupled would slow everything down.

Inference

To start, let's assume we have the solution and just need to serve it up.

// server
import * as sm from '@shumai/shumai'

const weight = sm.scalar(1.718)
weight.requires_grad = true

// keeping it trivial for the example
function model(X) {
  return X.mul(weight)
}

To expose this model to the network (err, the other team), we can use sm.io.serve(map, options) and map the /forward URL to an invocation.

// server
sm.io.serve(
  {
    forward: (user, X) => {
      const Y = model(X)
      return Y
    },
  },
  { port: 3000 }
)

And now we can run the model remotely with tfetch, which is effectively fetch but for tensors.

// client
import * as sm from '@shumai/shumai'
const url = 'localhost:3000'

// get_data() is a placeholder for the dataset team's loader (not defined in this post)
const [input, ref_output] = get_data()
const output = await sm.io.tfetch(`${url}/forward`, input)

That's it. Everything is wrapped up and serialized for you! 🚀
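For a rough idea of what "serialized" could mean here, below is a minimal sketch of packing a tensor's shape and values into bytes and unpacking them again. The wire format and the names encodeTensor and decodeTensor are made up for illustration. This is not Shumai's actual encoding, which tfetch handles for you.

```javascript
// Hypothetical wire format (an assumption, NOT Shumai's):
// [rank: u32][dims: u32 * rank][values: f32 * count], all little-endian.
function encodeTensor(shape, data) {
  const headerBytes = 4 * (1 + shape.length)
  const buffer = new ArrayBuffer(headerBytes + 4 * data.length)
  const view = new DataView(buffer)
  view.setUint32(0, shape.length, true)
  shape.forEach((dim, i) => view.setUint32(4 * (1 + i), dim, true))
  data.forEach((value, i) => view.setFloat32(headerBytes + 4 * i, value, true))
  return buffer
}

function decodeTensor(buffer) {
  const view = new DataView(buffer)
  const rank = view.getUint32(0, true)
  const shape = []
  for (let i = 0; i < rank; i++) {
    shape.push(view.getUint32(4 * (1 + i), true))
  }
  const headerBytes = 4 * (1 + rank)
  // the element count is the product of the dimensions
  const count = shape.reduce((a, b) => a * b, 1)
  const data = new Float32Array(count)
  for (let i = 0; i < count; i++) {
    data[i] = view.getFloat32(headerBytes + 4 * i, true)
  }
  return { shape, data }
}

// round trip: what goes in comes back out
const packed = encodeTensor([2, 2], [1.5, -2, 0.25, 8])
const unpacked = decodeTensor(packed)
```

A buffer like packed is the kind of thing that could travel as the body of an ordinary fetch request, which is all "fetch but for tensors" needs conceptually.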

Training

But what if our model is bad? We should tell the server about that.

Server Side

First we'll need to make the server able to update its model. Theoretically we could do something like this...

sm.io.serve(
  // ...
    optimize: (user, jacobian) => {
      const differentiated_tensors = Y.backward(jacobian)
      // update each tensor with a learning rate of 1e-3
      sm.optim.sgd(differentiated_tensors, 1e-3)
    }
  // ...
)

But we don't have Y available to us. 😢 In fact, every invocation of forward produces a new Y and potentially a different backward function (if the model weren't this simple). What can we do?

JavaScript closures to the rescue! We can capture Y in a function that runs our backward pass and save that function on the user object.

sm.io.serve(
  {
    forward: (user, X) => {
      const Y = model(X)
      // capture the optimization function and save it on the user object
      user.opt = (jacobian) => {
        sm.optim.sgd(Y.backward(jacobian), 1e-2)
      }
      return Y
    },
    optimize: (user, jacobian) => {
      user.opt(jacobian)
    }
  }
  // ...
)

(Technical Note: we're merging the backward and optimize steps. This is of course not required; it just makes the example simpler. We could always split up the server to calculate gradients first and then apply the update on a different call.)
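To see the capture trick in isolation, here's a minimal sketch in plain JavaScript that uses numbers in place of tensors. Nothing in it is a Shumai API: handlers, alice and bob are hypothetical names, and it assumes the server hands each client its own user object, as the example above relies on.

```javascript
// Plain-number stand-in for the server above (an illustration, not Shumai code).
let weight = 1.718

const handlers = {
  forward: (user, x) => {
    const y = x * weight
    // capture this call's input so the matching backward pass can use it later
    user.opt = (gradOutput) => {
      // for y = x * weight, dL/dweight = dL/dy * x
      weight -= 1e-2 * gradOutput * x
    }
    return y
  },
  optimize: (user, gradOutput) => {
    user.opt(gradOutput)
  },
}

// two clients, each with their own captured state
const alice = {}
const bob = {}
const yAlice = handlers.forward(alice, 2)
const yBob = handlers.forward(bob, 10)

// updates the weight using Alice's x = 2, not Bob's x = 10
handlers.optimize(alice, 1)
```

Because each forward call stores a fresh function on its own user object, interleaved requests from different clients don't overwrite each other's backward pass.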

Client Side

Now we need to calculate the Jacobian from the client side.

// client
output.requires_grad = true
const loss = sm.loss.mse(output, ref_output)
loss.backward()
await sm.io.tfetch(`${url}/optimize`, output.grad)

Done.
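For intuition about what gets sent to /optimize, here's a small sketch of the gradient of a mean squared error with respect to the output, written over plain arrays. It assumes the loss is the mean of the squared differences. Shumai's sm.loss.mse may reduce differently, and mse and mseGrad below are hypothetical helpers, not library functions.

```javascript
// Assumes mse = mean((output - ref)^2); the library's reduction may differ.
function mse(output, ref) {
  let sum = 0
  for (let i = 0; i < output.length; i++) {
    sum += (output[i] - ref[i]) ** 2
  }
  return sum / output.length
}

// d(mse)/d(output[i]) = 2 * (output[i] - ref[i]) / N
function mseGrad(output, ref) {
  return output.map((o, i) => (2 * (o - ref[i])) / output.length)
}

const exampleLoss = mse([3, 5], [1, 5])
const exampleGrad = mseGrad([3, 5], [1, 5])
```

The gradient has the same shape as the output, which is why the server can feed it straight into the backward pass it captured for that output.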

If you're not a huge fan of the stateful requires_grad semantics and would prefer a more functional solution, we can actually pass a backward pass callback to the original tfetch:

// client

// this will be called automatically
const backward = async (grad) => {
  await sm.io.tfetch(`${url}/optimize`, grad.grad_in)
}

const output = await sm.io.tfetch(`${url}/forward`, input, { grad_fn: backward })

const loss = sm.loss.mse(output, ref_output)
await loss.backward()

This of course will need to be wrapped in a loop. I like using sm.util.viter to get visual iteration.

for (const _ of sm.util.viter(200)) {
  // ...
}

Code Listing

The code in the post can be found here: client and server.

If you're interested, a more involved model parallel example lives here. The server semantics popular in JavaScript scale quite well to robust multi-node training regimes. Some interesting artifacts of this way of setting up training: