Unexposed Orchestration Logic in Neural Network Libraries

Many neural network libraries like to split machine learning into two parts: model definition and everything else. I think this split is limiting product development.

The frontends have gotten a lot better

Back in the day, model definition was declarative. The popular frameworks were TensorFlow and Caffe, and the model code was effectively just a data structure. Caffe models actually were data structures: plain protobuf text files.

layer {
  name: "conv1"
  type: "Convolution"
  param { lr_mult: 1 }
  param { lr_mult: 2 }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
  bottom: "data"
  top: "conv1"
}

It's easy to say "yuck" now that we've seen the beauty of PyTorch and TensorFlow 2.0, which adopt an imperative style.

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)

    def forward(self, x):
        x = self.conv1(x)
        return x
Clearly, great strides have been made on the frontend. But it's still pretty hard to understand and manipulate how these models are run.

Is it running on a specific GPU? Is it multi-threaded? Is it using all the cores? Can I limit that? Can I run it on my neural processor? Can I hook into the operator dispatch mechanism? 🤷
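
To be fair, a few of these knobs do exist. In PyTorch, for instance, thread counts and device placement are controllable, but only as coarse, mostly process-wide settings. A minimal sketch:

import torch
import torch.nn as nn

# Thread control is global, not per-op or per-model.
torch.set_num_threads(2)           # intra-op parallelism
torch.set_num_interop_threads(2)   # inter-op parallelism

# Device placement moves whole modules/tensors, not operations.
net = nn.Conv2d(1, 32, 3, 1)
if torch.cuda.is_available():
    net = net.to("cuda:0")

Anything finer-grained than that, like hooking the operator dispatch itself, means reading the C++ internals.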

Why even use a high-level library if you care about the nitty-gritty?

I really appreciate the beauty of an API that runs convolution efficiently without the boilerplate found in libraries like oneDNN and cuDNN. I like that I don't have to think about aligned mallocs or whether my CPU supports AVX-512.

But I'm not a researcher. I'm just a regular engineer.

I mess around with learning rates and architecture sizes, and then either train or fine-tune models on my own data. Then I try to build a product around that.

And when I'm building the product, I really start to feel how underserved my use case is.

What does a product need that isn't well served by multi-hundred-megabyte Python libraries?

There are obvious annoyances. These libraries are huge, and Python is genuinely hard to ship: it's slow, and deployments are brittle.

There have been some improvements in this regard. ONNX, TorchScript, and TensorFlow's SavedModel allow model export (if your model fits within the well-defined limits of these solutions). The experience is kind of like exporting a Word document to a JPEG: the content survives, but the structure is gone.
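
For what it's worth, the export path in PyTorch looks something like this (a minimal sketch using the Net module from above; the file names are mine):

import torch

model = Net().eval()
example = torch.randn(1, 1, 28, 28)

# TorchScript: trace the model into a self-contained graph.
scripted = torch.jit.trace(model, example)
scripted.save("net.pt")

# ONNX: export the same model to an interchange format.
torch.onnx.export(model, example, "net.onnx")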

But there are some other things I also find myself yearning for. I call these things "orchestration" because they have to do with how the models are executed.

For example, I might want to run certain operations on certain threads.

Or I may want to ensure that two models running simultaneously don't overlap in their usage of the matmul accelerator on the chip.

Or I might want to lazily execute a block of operations to give the compiler a fighting chance of optimizing things.

And every so often I may want to jump onto a GPU and then back to a CPU and then over to an accelerator.
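
The device-hopping part is at least expressible today, though only as explicit whole-tensor copies (a sketch assuming a CUDA device is available):

import torch

x = torch.randn(64, 64)

y = x.to("cuda")   # jump onto the GPU
y = y @ y          # the matmul runs on the GPU
z = y.to("cpu")    # ...and back to the CPU
z = torch.relu(z)

# There's no comparable hook for "which thread", "which core",
# or "don't contend with that other model for the accelerator".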

And generally I want to do this without having to inject my code into the 100k+ lines of C++ that back these libraries.

I want the comfy Python experience I've grown accustomed to!

Python is too slow for this

It'd be cool if PyTorch and TensorFlow made it easier to control these things. Perhaps an nn.Executor I could inherit from (like nn.Module)?
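
To be concrete, here's the kind of thing I'm imagining. This is purely hypothetical: nn.Executor, dispatch, run_on, and the pool argument are all made up by me, not real PyTorch APIs.

import torch.nn as nn

# Purely hypothetical API sketch. Nothing below exists in PyTorch.
class MyExecutor(nn.Executor):  # nn.Executor is imagined, like nn.Module
    def dispatch(self, op, inputs):
        # Pin convolutions to a dedicated thread pool; let
        # everything else fall through to the default behavior.
        if op.name == "conv2d":
            return self.run_on(op, inputs, pool="conv-threads")
        return super().dispatch(op, inputs)

The appeal is the same as nn.Module's: override one well-known method instead of patching the C++ internals.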

But there's a bigger issue. If I write my fancy orchestration logic in Python there's still no chance I can ship it.

The interpreter is just too generic and slow. Python is certainly fine for database requests that take on the order of milliseconds, which is also around how long a full neural network might take to run. But if we're orchestrating the op-to-op execution of the model, we've jumped down into the microsecond-latency range and will certainly need a faster language.
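
To put numbers on that, here's a rough sketch of how I'd measure per-op overhead from Python (exact figures will vary by machine and build):

import time
import torch

x = torch.randn(8, 8)

n = 100_000
start = time.perf_counter()
for _ in range(n):
    x = torch.relu(x)  # a tiny op: dispatch overhead dominates
elapsed = time.perf_counter() - start

# On small tensors this typically lands in the low-microsecond
# range per call, so any per-op Python logic is a real tax.
print(f"{elapsed / n * 1e6:.1f} us per op")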

JavaScript to the rescue?

JavaScript is a really interesting language. Just like Python, it is high-level and blissfully easy to use and iterate with. Just like Python, it has a ton of adoption: a large ecosystem of libraries, helpful Q&A banks, and tutorials. Just like Python, you can easily bind native C++ functions to one of the more popular runtimes, V8.

Unlike Python's interpreter, the most popular JavaScript runtimes have built-in JIT compilers that make the language a lot faster. On my machine, fib(30) runs in 202ms on Python 3.8 and in only 11ms on Node.js 14. That's roughly 20x faster.
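
For reference, the benchmark is the usual naive recursion (the Node.js version is a line-for-line translation of this Python sketch):

import time

def fib(n):
    # Deliberately naive recursion: the point is to stress the
    # interpreter's function-call overhead, not to compute fib well.
    if n < 2:
        return n
    return fib(n - 1) + fib(n - 2)

start = time.perf_counter()
fib(30)
print(f"{(time.perf_counter() - start) * 1000:.0f} ms")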

In my opinion, this reason alone makes JavaScript a prime candidate for the next big machine learning framework.

However, if JavaScript genuinely does become a primary language in machine learning workflows, there is a big cherry on top.

It can run in the browser natively.

But browser JS !== NodeJS

That's true, but even without native C++ bindings, there's been a lot of really cool work done with WASM and WebGL.

My favorite example of in-browser machine learning is Google's FaceMesh demo which is, to me, a shocking demonstration of how fast browser JS has become (and a testament to the TensorFlow team's work).

Conclusion

I'm very grateful for all the work that has gone into neural network libraries. These libraries deal with many complex problems under the hood: device selection, threading, memory planning, etc.

The APIs that expose knobs for these problems are reminiscent of the early declarative programming style that we happily abandoned years ago.

I prefer not to think in terms of "configurations" and would love to see a framework that empowers product-driven use cases. And ideally, it'd be in JavaScript.

🙂