# train_matrix_exp_estim_jax.py
# Discussion here: https://news.ycombinator.com/item?id=31031812
# Better (faster) script from Patrick here: https://gist.github.com/patrick-kidger/68bf7b99ba02c246b20eaa38f2ad3d38
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import time
class MatrixExponentEstimator(eqx.Module):
    """Three-layer MLP mapping a length-4 vector to a length-4 vector.

    Layer widths are 4 -> 32 -> 16 -> 4. ``tanh`` is applied after the
    first two linear layers; the output layer is left linear.
    """

    # Linear layers, applied in order d0 -> d1 -> d2.
    d0: eqx.nn.Linear
    d1: eqx.nn.Linear
    d2: eqx.nn.Linear

    def __init__(self, rk):
        """Build the three linear layers.

        Args:
            rk: JAX PRNG key used to derive the layers' initialisation keys.
        """
        # Split once into three independent subkeys so that no key is both
        # consumed by a layer and split again afterwards (JAX keys must not
        # be reused).
        key0, key1, key2 = jax.random.split(rk, 3)
        self.d0 = eqx.nn.Linear(4, 32, key=key0)
        self.d1 = eqx.nn.Linear(32, 16, key=key1)
        self.d2 = eqx.nn.Linear(16, 4, key=key2)

    def __call__(self, x):
        """Run the network on a single length-4 input and return a length-4 output."""
        x = jax.numpy.tanh(self.d0(x))
        x = jax.numpy.tanh(self.d1(x))
        return self.d2(x)
def f(x):
    """Return the matrix exponential of ``x`` viewed as a 2x2 matrix, flattened to shape (4,)."""
    matrix = x.reshape((2, 2))
    exponential = jax.scipy.linalg.expm(matrix)
    return exponential.reshape((4,))
def apply_matrix_exponential(x):
    """Apply ``f`` to every 1-D slice of ``x`` taken along axis 1."""
    per_row = jnp.apply_along_axis(f, 1, x)
    return per_row
def train(epochs=10000, rounds=3, seed=1337):
    """Train a MatrixExponentEstimator on random 2x2 matrices and print timings.

    Draws 10000 training and 10000 test inputs from a standard normal,
    computes their targets with ``apply_matrix_exponential``, and runs
    full-batch Adam (learning rate 1e-3) on the mean squared error.
    Losses are printed before training and after every round.

    Args:
        epochs: Number of full-batch gradient steps per timed round.
        rounds: Number of timed rounds to run back to back.
        seed: Integer seed for the JAX PRNG.
    """
    # One independent key per consumer; a key that has been used to draw
    # samples is never split or passed on again.
    train_key, test_key, model_key = jax.random.split(
        jax.random.PRNGKey(seed), 3
    )

    trainx = jax.random.normal(train_key, shape=(10000, 2 * 2))
    trainy = apply_matrix_exponential(trainx)
    testx = jax.random.normal(test_key, shape=(10000, 2 * 2))
    testy = apply_matrix_exponential(testx)

    model = MatrixExponentEstimator(model_key)
    adam = optax.adam(1e-3)
    opt_state = adam.init(model)

    @jax.jit
    @jax.value_and_grad
    def loss_fn(model, X, y):
        # vmap lifts the single-example model over the batch dimension.
        err = jax.vmap(model)(X) - y
        return jnp.mean(jnp.square(err))  # mse

    print('Initial Train Loss: {:.4f}'.format(loss_fn(model, trainx, trainy)[0].item()))
    print('Initial Test Loss: {:.4f}'.format(loss_fn(model, testx, testy)[0].item()))

    for _round in range(rounds):
        t_start = time.perf_counter()
        for _step in range(epochs):
            _, grads = loss_fn(model, trainx, trainy)
            updates, opt_state = adam.update(grads, opt_state)
            model = eqx.apply_updates(model, updates)
        # JAX dispatches work asynchronously; wait for the final parameter
        # update to finish so the elapsed time covers the actual compute.
        jax.block_until_ready(model)
        print('Took: {:.2f} seconds'.format(time.perf_counter() - t_start))
        print('Train Loss: {:.4f}'.format(loss_fn(model, trainx, trainy)[0].item()))
        print('Test Loss: {:.4f}'.format(loss_fn(model, testx, testy)[0].item()))
# Run the training benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    train()