Optimizing Using JAX
From Scratch
Step Function
# STEP FUNCTION
@jax.jit
def step(params, X, y, opt_state, i=0):
    """Run a single optimization step.

    Args:
        params: current parameters; forwarded to ``mll_loss`` and ``dloss``.
        X: forwarded unchanged to ``mll_loss`` and ``dloss``
            (presumably the training inputs -- confirm against caller).
        y: forwarded unchanged to ``mll_loss`` and ``dloss``
            (presumably the training targets -- confirm against caller).
        opt_state: optimizer state accepted by ``opt_update``.
        i: iteration index handed to ``opt_update`` as its first argument.
            Defaults to 0, which reproduces the previous hard-coded
            behaviour, so existing four-argument calls keep working.

    Returns:
        Tuple ``(params, opt_state, loss)``: the parameters read back from
        the updated optimizer state, that updated state, and the loss
        evaluated at the parameters that were passed in (i.e. before the
        update).
    """
    # loss at the incoming parameters (reported to the caller, not used
    # for the update itself)
    loss = mll_loss(params, X, y)
    # gradient of the loss with respect to the parameters
    grads = dloss(params, X, y)
    # Forward the caller's iteration index rather than a constant 0:
    # optimizers whose update rule depends on the step count (e.g. a
    # step-size schedule) would otherwise treat every call as the first.
    opt_state = opt_update(i, grads, opt_state)
    # extract the updated parameters from the new optimizer state
    params = get_params(opt_state)
    return params, opt_state, loss
Next, we set the training hyperparameters, create the optimizer, and initialize the optimizer state from the starting parameters.
# TRAINING PARAMETERS
learning_rate = 0.01
# a smoke test only needs a couple of epochs to prove the loop runs
n_epochs = 2 if args.smoke_test else 500
# loss history, appended to once per epoch by the training loop
losses = []

# build the RMSProp optimizer; it comes back as three functions:
# state initializer, state updater, and parameter extractor
opt_init, opt_update, get_params = optimizers.rmsprop(step_size=learning_rate)

# wrap the starting parameters in an optimizer state ...
opt_state = opt_init(params)
# ... and read the parameters back out of that state
params = get_params(opt_state)
Lastly, let's run the actual training loop.
# progress-bar annotations; remains empty if no epoch is run
postfix = {}
with tqdm.trange(n_epochs) as bar:
    for i in bar:
        # one optimization step: new parameters, new optimizer state, loss
        params, opt_state, value = step(params, X, y, opt_state)

        # keep the mean loss of this epoch in the history
        mean_loss = value.mean()
        losses.append(mean_loss)

        # rebuild the display entries: every parameter plus the latest
        # loss, each formatted to two decimals
        postfix = {name: f"{params[name]:.2f}" for name in params}
        postfix["Loss"] = f"{onp.array(mean_loss):.2f}"

        # show them next to the progress bar
        bar.set_postfix(postfix)