When optimizing the weights and biases of a model, does it make sense to replace:
```python
for _ in range(epochs):
    w, b = step(w, b)
```
With:
```python
w, b = lax.fori_loop(0, epochs, lambda i, wb: step(*wb), (w, b))
```
If I understand correctly, this means that the entire training process can then be a single compiled JAX function (takes in training data, outputs optimized weights and biases).
Is this a standard approach? What are the tradeoffs to consider?
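For concreteness, here is a minimal sketch of what I have in mind, with a made-up least-squares model standing in for the real one (`loss_fn`, the learning rate, and the update rule are all placeholders):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# Placeholder loss: least-squares regression on (X, y).
def loss_fn(w, b, X, y):
    pred = X @ w + b
    return jnp.mean((pred - y) ** 2)

# The whole training run is one compiled function:
# training data in, optimized weights and biases out.
@partial(jax.jit, static_argnames="epochs")
def train(w, b, X, y, epochs):
    grad_fn = jax.grad(loss_fn, argnums=(0, 1))

    def body(i, wb):
        w, b = wb
        gw, gb = grad_fn(w, b, X, y)
        return w - 0.1 * gw, b - 0.1 * gb  # plain gradient descent, lr = 0.1

    return lax.fori_loop(0, epochs, body, (w, b))
```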
It's fine to train your model with `fori_loop`, particularly for simple models. It may be slightly faster, but in general XLA won't fuse operations across different loop steps. It's also not possible to return early from a `fori_loop` when you reach a certain loss threshold, though you could do that with `while_loop` if you wish (see the first sketch below).

For more complicated models, you will often want to do some sort of I/O at every step (e.g. loading new training data, logging fit parameters, etc.). While this is possible within `fori_loop` via `jax.experimental.io_callback` (see the second sketch below), it is somewhat less convenient than doing it directly from the host within a Python `for` loop, so in general users tend to use `for` loops for their training iterations.
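Here's a rough sketch of the early-exit pattern with `while_loop`; the quadratic `loss` and hand-written gradient `step` are toy stand-ins for a real model:

```python
from jax import lax

# Toy loss and its gradient-descent update (grad is (2w, 2b), lr = 0.1).
def loss(w, b):
    return w ** 2 + b ** 2

def step(w, b):
    return w - 0.1 * 2 * w, b - 0.1 * 2 * b

def train_until(w, b, threshold=1e-3, max_steps=10_000):
    # Keep looping while the loss is above the threshold; cap the number
    # of steps so the loop is guaranteed to terminate.
    def cond(state):
        i, w, b = state
        return (loss(w, b) > threshold) & (i < max_steps)

    def body(state):
        i, w, b = state
        w, b = step(w, b)
        return i + 1, w, b

    _, w, b = lax.while_loop(cond, body, (0, w, b))
    return w, b

w_opt, b_opt = train_until(3.0, -2.0)
```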
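And for completeness, a sketch of host-side logging from inside `fori_loop` using `io_callback`, again with a toy loss; in practice the host function could write to a logger, disk, etc.:

```python
import jax.numpy as jnp
from jax import lax
from jax.experimental import io_callback

def host_log(i, loss_val):
    # Runs on the host, outside the compiled computation.
    print(f"step {int(i)}: loss {float(loss_val):.4f}")

def body(i, wb):
    w, b = wb
    loss_val = w ** 2 + b ** 2  # toy loss, stands in for the real one
    # None: the callback returns nothing; ordered=True keeps logs in step order.
    io_callback(host_log, None, i, loss_val, ordered=True)
    return w - 0.2 * w, b - 0.2 * b  # gradient descent on the toy loss

w, b = lax.fori_loop(0, 5, body, (jnp.float32(3.0), jnp.float32(-2.0)))
```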