We should revert the train step measurement PR. The correct way to measure this would be to run a few iterations, call jax.block_until_ready on the result, and then divide the elapsed time by the number of steps.
Since we do not want jax.block_until_ready in production code, I propose implementing this in, e.g., a draft PR for now.
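
For reference, a rough sketch of what such a measurement could look like. `train_step` here is a hypothetical stand-in for the real step function, `mean_step_time` is a made-up helper name, and the warm-up loop is my own addition to keep compilation out of the timed region; none of this is taken from the PR.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, batch):
    # Hypothetical stand-in for the real train step: one SGD update on a
    # simple least-squares loss.
    def loss_fn(p):
        return jnp.mean((batch @ p) ** 2)

    grads = jax.grad(loss_fn)(params)
    return params - 0.01 * grads


def mean_step_time(step_fn, params, batch, num_steps=10, warmup_steps=2):
    # Assumption: warm up first so that jit compilation is not timed.
    for _ in range(warmup_steps):
        params = step_fn(params, batch)
    jax.block_until_ready(params)

    start = time.perf_counter()
    for _ in range(num_steps):
        params = step_fn(params, batch)
    # JAX dispatches work asynchronously, so block once at the end to make
    # sure all queued steps have actually finished before stopping the clock.
    jax.block_until_ready(params)
    elapsed = time.perf_counter() - start
    return elapsed / num_steps


params = jnp.ones((128,))
batch = jnp.ones((64, 128))
print(f"mean step time: {mean_step_time(train_step, params, batch):.6f} s")
```

Blocking only once after the loop keeps the synchronization cost out of the individual steps, and the whole thing can live in a standalone script so nothing in the production path needs to call jax.block_until_ready.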