# train_matrix_exp_estim_jax.py

# Discussion here: https://news.ycombinator.com/item?id=31031812
# Better (faster) script from Patrick here: https://gist.github.com/patrick-kidger/68bf7b99ba02c246b20eaa38f2ad3d38
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import time

class MatrixExponentEstimator(eqx.Module):
    """Small MLP that maps a length-4 vector to a length-4 vector.

    Architecture: Linear(4 -> 32) -> tanh -> Linear(32 -> 16) -> tanh ->
    Linear(16 -> 4). In this script it is trained to approximate the
    flattened exponential of a flattened 2x2 matrix.
    """
    d0: eqx.nn.Linear  # input layer, 4 -> 32
    d1: eqx.nn.Linear  # hidden layer, 32 -> 16
    d2: eqx.nn.Linear  # output layer, 16 -> 4

    def __init__(self, rk):
        """Initialize the three linear layers.

        Args:
            rk: a `jax.random` PRNG key. It is split once into one
                independent subkey per layer and is not otherwise consumed.
        """
        # Split once into three subkeys so that no key is both consumed by a
        # layer and split again. Doing so would hand the next layer a key
        # that the previous layer already used internally for its own init.
        k0, k1, k2 = jax.random.split(rk, 3)
        self.d0 = eqx.nn.Linear(4, 32, key=k0)
        self.d1 = eqx.nn.Linear(32, 16, key=k1)
        self.d2 = eqx.nn.Linear(16, 4, key=k2)

    def __call__(self, x):
        """Apply the network to a single length-4 input.

        Batched inputs must be mapped with `jax.vmap`, as `train` does.
        """
        x = jnp.tanh(self.d0(x))
        x = jnp.tanh(self.d1(x))
        return self.d2(x)

def f(x):
    """Exponentiate a 2x2 matrix supplied as a flat vector of 4 values.

    The vector is read row-major as a 2x2 matrix. Its matrix exponential is
    returned flattened back to shape (4,).
    """
    square = x.reshape((2, 2))
    exponentiated = jax.scipy.linalg.expm(square)
    return exponentiated.ravel()

def apply_matrix_exponential(x):
    """Apply `f` to every 1-D slice of `x` taken along axis 1.

    For an (N, 4) input, each row is treated as a flattened 2x2 matrix and
    replaced by its flattened matrix exponential.
    """
    row_axis = 1
    return jax.numpy.apply_along_axis(f, row_axis, x)

def train():
    """Fit a MatrixExponentEstimator to 2x2 matrix exponentials and print progress.

    Draws 10000 training and 10000 test inputs from a standard normal
    distribution (each a flattened 2x2 matrix) and labels them with
    `apply_matrix_exponential`. It then trains with Adam (learning rate 1e-3)
    on the full training batch for 3 rounds of `epochs` steps each. After
    every round it prints the wall-clock time and the train/test mean squared
    error.
    """
    epochs = 10000  # full-batch optimizer steps per round
    # Split the root key once so every consumer gets its own key. A JAX key
    # must not be reused, whether for sampling or for further splitting.
    train_key, test_key, model_key = jax.random.split(jax.random.PRNGKey(1337), 3)
    trainx = jax.random.normal(train_key, shape=(10000, 2*2))
    trainy = apply_matrix_exponential(trainx)
    testx = jax.random.normal(test_key, shape=(10000, 2*2))
    testy = apply_matrix_exponential(testx)

    model = MatrixExponentEstimator(model_key)
    adam = optax.adam(1e-3)
    opt_state = adam.init(model)

    def mse(model, X, y):
        """Mean squared error of the per-example model outputs against y."""
        err = jax.vmap(model)(X) - y
        return jnp.mean(jnp.square(err))

    # Reporting only needs the loss value, so skip the gradient computation.
    eval_loss = jax.jit(mse)

    @jax.jit
    def train_step(model, opt_state, X, y):
        """One Adam step on (X, y); returns the updated model and optimizer state."""
        # Gradient, optimizer update and parameter update are compiled into a
        # single call, so each step is one dispatch instead of many eager ops.
        grads = jax.grad(mse)(model, X, y)
        updates, opt_state = adam.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state

    print('Initial Train Loss: {:.4f}'.format(eval_loss(model, trainx, trainy).item()))
    print('Initial Test Loss: {:.4f}'.format(eval_loss(model, testx, testy).item()))
    for _ in range(3):
        t_start = time.perf_counter()
        for _ in range(epochs):
            model, opt_state = train_step(model, opt_state, trainx, trainy)
        # JAX dispatches work asynchronously. Wait for the last step to finish
        # so the measured time covers the actual computation.
        for leaf in jax.tree_util.tree_leaves(model):
            leaf.block_until_ready()
        print('Took: {:.2f} seconds'.format(time.perf_counter() - t_start))
        print('Train Loss: {:.4f}'.format(eval_loss(model, trainx, trainy).item()))
        print('Test Loss: {:.4f}'.format(eval_loss(model, testx, testy).item()))

if __name__ == '__main__':
    # Run the training demo only when executed as a script, not on import.
    train()