# Fully Embracing JavaScript: Machine Learning Training as Async Services

JavaScript has some great ideas in the server space.
Of those, async programming and reactive server definitions are potentially
very useful in machine learning.

Below are some ideas found in [Shumai](https://github.com/facebookresearch/shumai),
an experimental framework to expose machine learning research
and training to the server-side JavaScript world.

## Problem Statement

To motivate this post, let's pretend we've got two teams working
on an ML pipeline.
One is building the model; the other is assembling a dataset
and defining a metric for correctness.

```text

   Model Team            Dataset Team
       |                      |
       v                      v
      
  [  Server  ] < ===== > [  Client  ]

   - weights             - dataset
   - architecture        - loss function
   - GPUs                - CPUs


```

Each team owns the hardware they're building for.
The model team has big GPUs that can churn through very large machine learning models.
The dataset team has very fast
CPUs that can generate lots of data.  They don't need GPUs.
If bandwidth demands it, the hardware can be co-located.

It's also important to note that the teams iterate at different rates.
Keeping the problems tightly coupled would slow everything down.

## Inference

To start, let's assume we have the solution and just need to serve it up.

```javascript
// server
import * as sm from '@shumai/shumai'

const weight = sm.scalar(1.718)
weight.requires_grad = true

// keeping it trivial for the example
function model(X) {
  return X.mul(weight)
}
```

To expose this model to the network (err, the other team),
we can use `sm.io.serve(map, options)` and map the `/forward` URL to an
invocation.

```javascript
// server
sm.io.serve(
  {
    forward: (user, X) => {
      const Y = model(X)
      return Y
    },
  },
  { port: 3000 }
)
```

And now we can run the model remotely with `tfetch`, which is effectively
[`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/fetch) but for tensors.

```javascript
// client
import * as sm from '@shumai/shumai'
const url = 'localhost:3000'

const [input, ref_output] = get_data()
const output = await sm.io.tfetch(`${url}/forward`, input)
```

That's it.  Everything is wrapped up and serialized for you! 🚀
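As an aside, `get_data` above is left abstract. One hypothetical stand-in, generating (input, reference output) pairs for a toy target `y = 2x` as plain `Float32Array`s (a real client would wrap these with `sm.tensor` before handing them to `tfetch`):

```javascript
// Hypothetical stand-in for get_data(): synthesize (input, reference)
// pairs for a toy target y = 2x. A real client would wrap these arrays
// with sm.tensor before sending them over tfetch.
function get_data(n = 8) {
  const input = new Float32Array(n)
  const ref_output = new Float32Array(n)
  for (let i = 0; i < n; i++) {
    input[i] = Math.random()
    ref_output[i] = 2 * input[i] // the "truth" the server should learn
  }
  return [input, ref_output]
}
```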

## Training

But what if our model is bad?  We should tell the server about that. 

### Server Side

First we'll need to make the server *able* to update its model.
Theoretically we could do something like this...

```javascript
sm.io.serve(
  // ...
    optimize: (user, jacobian) => {
      const differentiated_tensors = Y.backward(jacobian)
      // update each tensor with a learning rate of 1e-3
      sm.optim.sgd(differentiated_tensors, 1e-3)
    }
  // ...
)
```
But we don't have `Y` available to us. 😢  In fact, every invocation
to `forward` is going to change `Y` and potentially change the
backward function (if the model wasn't this simple).  What can we do?

JavaScript captures to the rescue!  We can save an invocation to our backward pass
on the user object.

```javascript
sm.io.serve(
  {
    forward: (user, X) => {
      const Y = model(X)
      // capture the optimization function and save it on the user object
      user.opt = (jacobian) => {
        sm.optim.sgd(Y.backward(jacobian), 1e-3)
      }
      return Y
    },
    optimize: (user, jacobian) => {
      user.opt(jacobian)
    }
  }
  // ...
)
```
(*Technical Note: we're merging backward and optimize steps.
This is of course not required; it just makes the example simpler.
We could always split up the server to calculate gradients first and then
apply the update on a different call.*)
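For reference, the split version might look something like this (a sketch, reusing the handler-map style from above; the `user.grads` slot is a hypothetical name, not part of the shumai API):

```javascript
// Sketch: gradients computed on one call, applied on another.
// `model` and `sm` are as defined earlier; `user.grads` is a
// hypothetical per-user slot for stashing gradients between calls.
const handlers = {
  forward: (user, X) => {
    const Y = model(X)
    user.bwd = (jacobian) => Y.backward(jacobian)
    return Y
  },
  backward: (user, jacobian) => {
    // compute and stash the differentiated tensors; don't apply them yet
    user.grads = user.bwd(jacobian)
  },
  optimize: (user) => {
    // apply the stashed update on a separate call
    sm.optim.sgd(user.grads, 1e-3)
  },
}
// sm.io.serve(handlers, { port: 3000 })
```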

### Client Side

Now we need to calculate the Jacobian from the client side.

```javascript
// client
output.requires_grad = true
const loss = sm.loss.mse(output, ref_output)
loss.backward()
await sm.io.tfetch(`${url}/optimize`, output.grad)
```

Done.

If you're not a huge fan of the stateful `requires_grad` semantics
and would prefer a more functional solution, we can actually pass a
backward pass callback to the original `tfetch`:

```javascript
// client

// this will be called automatically
const backward = async (grad) => {
  await sm.io.tfetch(`${url}/optimize`, grad.grad_in)
}

const output = await sm.io.tfetch(`${url}/forward`, input, { grad_fn: backward })

const loss = sm.loss.mse(output, ref_output)
await loss.backward()
```
This of course will need to be wrapped in a loop.
I like using `sm.util.viter` to get visual iteration.

```javascript
for (const _ of sm.util.viter(200)) {
  // ...
}
```
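Putting the functional variant together, the whole client loop might look something like this (a sketch; `sm`, `url`, and `get_data` are as defined earlier):

```javascript
// Sketch of the full client training loop, using the grad_fn callback
// style from above. sm, url, and get_data are as defined earlier.
async function train(steps = 200) {
  for (const _ of sm.util.viter(steps)) {
    const [input, ref_output] = get_data()
    // called automatically when loss.backward() propagates to `output`
    const backward = async (grad) => {
      await sm.io.tfetch(`${url}/optimize`, grad.grad_in)
    }
    const output = await sm.io.tfetch(`${url}/forward`, input, { grad_fn: backward })
    const loss = sm.loss.mse(output, ref_output)
    await loss.backward()
  }
}
```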
![](https://i.imgur.com/yEKqKpn.gif)

## Code Listing

The code in the post can be found here: [client](https://github.com/facebookresearch/shumai/blob/main/examples/client.ts)
and [server](https://github.com/facebookresearch/shumai/blob/main/examples/serve.ts).

If you're interested, a more involved model parallel example lives [here](https://github.com/facebookresearch/shumai/blob/main/examples/distributed/README.md).
The server semantics popular in JavaScript scale quite well to
robust multi-node training regimes.
Some interesting artifacts of this way of setting up training:
- If you have parallel data feeds, pipelining is automatic
- Failed forward execution is graceful (servers stay alive, waiting for requests)
- Parallel calls to sub-models are trivial (`await Promise.all(submodels)`)
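That last point deserves a small illustration: fanning one input out to several sub-model servers in parallel (a sketch; the URLs and the `combine` function are hypothetical, and `sm` is as above):

```javascript
// Sketch: call several sub-model servers in parallel and combine the
// partial results. The URLs and combine() are hypothetical.
async function forward_parallel(input, combine) {
  const submodel_urls = ['localhost:3001', 'localhost:3002']
  const outputs = await Promise.all(
    submodel_urls.map((u) => sm.io.tfetch(`${u}/forward`, input))
  )
  // merge the partial outputs however the architecture dictates
  return combine(outputs)
}
```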