We would like to run our experiments on TPUs. To do this, we need to use PyTorch/XLA. This involves at least the following basics:
1. Python 3.8: PyTorch/XLA runs only on Python 3.8, but a small number of our dependencies and parts of our code require Python 3.9+. We should identify these and work around them. See also devinterp issue 11.
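For reference, here is a minimal sketch of 3.8-compatible replacements for a few common 3.9+ idioms: dict merging with `|`, `str.removeprefix`, and builtin generics such as `list[int]` in type hints. These idioms are illustrative examples, not an audit of what our code actually uses, and the helper names are hypothetical.

```python
# Runs on Python 3.8: postponed evaluation of annotations lets type hints use
# PEP 585 builtin generics (list[int]) and PEP 604 unions (int | None),
# because the annotations are stored as strings and never evaluated.
from __future__ import annotations

from typing import Dict, List, Optional


def merge_configs(base: dict[str, int], override: dict[str, int]) -> dict[str, int]:
    """3.8-compatible replacement for the 3.9+ expression `base | override`."""
    return {**base, **override}


def remove_prefix(text: str, prefix: str) -> str:
    """3.8-compatible replacement for the 3.9+ `text.removeprefix(prefix)`."""
    if prefix and text.startswith(prefix):
        return text[len(prefix):]
    return text


def first_or_none(items: list[int]) -> int | None:
    """Fine on 3.8 because these annotations are never evaluated at runtime."""
    return items[0] if items else None


# Outside annotations (e.g. type aliases, which ARE evaluated at runtime),
# Python 3.8 still needs the generics from the typing module.
Batch = List[Dict[str, Optional[int]]]
```

Running the test suite under a Python 3.8 interpreter is probably the quickest way to surface the remaining incompatibilities, since most of them fail at import or call time.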
2. Deterministic randomness: Does XLA have its own TPU-based RNG? Are we seeding it? Look into this and make sure we are getting reproducible runs.
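As far as we know, PyTorch/XLA keeps a separate RNG state for XLA devices, which `torch_xla.core.xla_model.set_rng_state(seed)` sets. Below is a minimal sketch of a single seeding entry point, assuming our runs draw randomness from Python's `random`, NumPy, PyTorch and PyTorch/XLA. `seed_everything` is a hypothetical helper, and each library is seeded only if it can be imported, so the sketch also runs off-TPU.

```python
import random
from typing import List


def seed_everything(seed: int) -> List[str]:
    """Seed every RNG we know about and return the names of those seeded.

    Hypothetical helper: third-party libraries are seeded only if installed.
    """
    seeded = []

    random.seed(seed)
    seeded.append("random")

    try:
        import numpy as np
        np.random.seed(seed)
        seeded.append("numpy")
    except ImportError:
        pass

    try:
        import torch
        torch.manual_seed(seed)  # seeds PyTorch's CPU (and CUDA) generators
        seeded.append("torch")
    except ImportError:
        pass

    try:
        # PyTorch/XLA maintains its own RNG state for XLA devices.
        import torch_xla.core.xla_model as xm
        xm.set_rng_state(seed)
        seeded.append("torch_xla")
    except ImportError:
        pass

    return seeded
```

Whether this is enough for identical runs on the TPU still needs checking, e.g. by running the same seed twice and comparing the logged losses.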
3. Checkpointing: We need to make sure the exported models can be loaded properly. We had some trouble with this, so it seems we should first move them to the CPU, then save the checkpoint.
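A minimal sketch of the "move to the CPU, then save" approach. `to_cpu` recursively walks a nested checkpoint structure and calls `.cpu()` on anything tensor-like; `save_checkpoint` and the checkpoint layout are hypothetical. PyTorch/XLA also provides `torch_xla.core.xla_model.save`, which transfers tensors to the CPU before saving, and may be simpler to use directly.

```python
from typing import Any


def to_cpu(obj: Any) -> Any:
    """Recursively copy tensors in a nested checkpoint structure to the CPU."""
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [to_cpu(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(to_cpu(value) for value in obj)
    if hasattr(obj, "cpu"):  # torch.Tensor on any device (XLA included) has .cpu()
        return obj.cpu()
    return obj  # ints, floats, strings, None, ...


def save_checkpoint(model, optimizer, step: int, path: str) -> None:
    """Hypothetical helper: gather state, move it to the CPU, then save."""
    import torch  # imported lazily so to_cpu works without torch installed

    state = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(to_cpu(state), path)
```

A checkpoint saved this way should load on any machine with `torch.load(path, map_location="cpu")`, after which the model can be moved back to the XLA device.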
4. ✅ Basic optimisation: For context, PyTorch/XLA tensors are computed lazily. Tensor operations build up a computational graph, and on demand (or explicit request) the graph is compiled by an optimising compiler and then executed on the TPU. The compilation step is very expensive and only pays off if the same graph is used repeatedly (such as in every iteration of a training loop), so that the compiled graph can be cached and reused. So, to get baseline performance, we need to:
   - make sure each iteration of the training loop uses the same computational graph
   - make sure there are no accidental demand points during the loop
   - insert a demand point at the end of each iteration
   - make sure it's working (i.e. that it runs faster than on the CPU) and adjust as needed

   That should lead to a ~3x speed-up once everything is working on the TPU.
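To make the caching argument concrete, here is a toy pure-Python model of it (not PyTorch/XLA itself): operations are recorded into a pending graph, a demand point compiles that graph if its structure has not been seen before, and compiled graphs are cached by structure. In PyTorch/XLA the explicit demand point at the end of an iteration is `xm.mark_step()` (from `torch_xla.core.xla_model`), and accidental ones include calling `.item()` on a tensor or printing it mid-loop. The class, the loop and the op names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


class ToyLazyDevice:
    """Toy model of lazy execution: ops are recorded, then compiled and run on demand.

    Compiled graphs are cached by structure (op names, including shapes), so an
    identical graph is compiled once and reused on every later demand.
    """

    def __init__(self) -> None:
        self.pending: List[str] = []  # ops recorded since the last demand point
        self.cache: Dict[Tuple[str, ...], str] = {}
        self.compilations = 0  # the expensive step we want to happen rarely
        self.executions = 0

    def op(self, name: str) -> None:
        """Record an op instead of running it, like an op on a lazy XLA tensor."""
        self.pending.append(name)

    def demand(self) -> None:
        """Compile the pending graph on a cache miss, then execute it."""
        if not self.pending:
            return
        key = tuple(self.pending)
        if key not in self.cache:
            self.compilations += 1
            self.cache[key] = " -> ".join(key)  # stand-in for a compiled program
        self.executions += 1
        self.pending = []


def toy_training_loop(device: ToyLazyDevice, batch_sizes: Sequence[int],
                      read_loss_midway: bool = False) -> None:
    """Hypothetical loop; each op name includes the batch size (its 'shape')."""
    for b in batch_sizes:
        device.op(f"forward[{b}]")
        device.op(f"loss[{b}]")
        if read_loss_midway:
            device.demand()  # accidental demand point, e.g. calling loss.item()
        device.op(f"backward[{b}]")
        device.op("optimizer_step")
        device.demand()  # explicit demand point, e.g. xm.mark_step()
```

In this model, a constant batch size gives one compilation for the whole run; a ragged final batch costs one extra compilation (dropping or padding it would avoid that); and an accidental demand point splits every iteration into two smaller graphs and doubles the number of executions.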
Then there are some pathways to further optimisation that seem low-hanging enough to be worth exploring:
5. ✅ Parallelisation across TPUs (10x speed-up): Google TPU Research Cloud offers 5 x TPU v2 and 5 x TPU v3, so that's 10 TPUs that can each be conducting independent training runs. The challenge here is to efficiently manage sweeps across 10 independent VMs.
   - W&B Sweeps is the appropriate tool for this. It's basically working.
   - Still looking for a way to reliably run experiments on a TPU VM without staying logged in over SSH.
   - Still looking for a convenient way to launch the agents with a single command from a local shell.
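Below is a sketch of launching one W&B agent per TPU VM from a local shell. It only builds the commands (and runs them if `dry_run=False`), assuming the `gcloud` CLI is set up locally and `wandb` is installed and logged in on every VM. The VM names, zones and sweep ID are hypothetical placeholders. Wrapping the agent in `nohup ... &` with all output redirected (or running it inside a `tmux`/`screen` session) is one common way to keep it running after the SSH session closes.

```python
import shlex
import subprocess
from typing import Dict, List


def agent_command(sweep_id: str, log_file: str = "wandb_agent.log") -> str:
    """Remote shell command: run a W&B agent in the background, detached from SSH."""
    agent = f"wandb agent {shlex.quote(sweep_id)}"
    # Redirecting stdin/stdout/stderr lets the SSH session return immediately.
    return f"nohup {agent} > {shlex.quote(log_file)} 2>&1 < /dev/null &"


def ssh_command(tpu_name: str, zone: str, remote_command: str) -> List[str]:
    """Local command (as an argv list) that runs remote_command on a TPU VM."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        f"--command={remote_command}",
    ]


def launch_agents(tpus: Dict[str, str], sweep_id: str, dry_run: bool = True) -> List[List[str]]:
    """Start one agent per TPU VM; `tpus` maps VM name -> zone. Returns the commands."""
    commands = [ssh_command(name, zone, agent_command(sweep_id)) for name, zone in tpus.items()]
    if not dry_run:
        for cmd in commands:
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # Hypothetical VM names, zones and sweep ID; substitute our real ones.
    example_tpus = {"tpu-v2-0": "us-central1-f", "tpu-v3-0": "europe-west4-a"}
    for cmd in launch_agents(example_tpus, "my-entity/my-project/abc123"):
        print(shlex.join(cmd))
```

Keeping the VM list in one place means the same script could also cover newly created VMs, e.g. preemptible ones, without changes.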
6. ✅ Parallelisation within TPUs (up to 4x speed-up): Each TPU v2-8 or v3-8 actually has four two-core chips (so-called 'devices') that can compute in parallel. In other words, so far we are only using 1/4 of each TPU. Options for further parallelisation across the four chips:
   - Parallelise across the batch: split each batch into four, run the forward and backward passes, aggregate the gradients, update, repeat. This is the common approach taken by most code examples I've seen. Our batches and models are pretty small, so this may not be worth the synchronisation overhead.
   - Parallelise across runs: have each of the four devices do a training run in a separate, non-interacting process. This seems easier. I don't foresee any serious bottlenecks: I don't think a single process comes anywhere near 25% of the CPU/TPU memory limits; network access for W&B syncing seems unlikely to be a bottleneck; CPU-TPU communication may become a bottleneck, but hopefully we are compute-bound and CPU access falls into a rhythm 🕺.
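The batch-splitting option relies on the fact that, for a loss averaged over the batch, averaging the gradients from equal-sized shards gives exactly the full-batch gradient. Below is a small pure-Python check of that, using a hypothetical one-parameter linear model with squared loss. On the TPU, the averaging step would be an all-reduce across the four devices; as we understand it, PyTorch/XLA's data-parallel examples do this via `xm.optimizer_step(optimizer)`, which reduces gradients across devices before stepping.

```python
from typing import List, Sequence, Tuple

Example = Tuple[float, float]  # (x, y)


def grad_mse(w: float, batch: Sequence[Example]) -> float:
    """Gradient with respect to w of mean((w*x - y)^2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def split(batch: Sequence[Example], n_shards: int) -> List[Sequence[Example]]:
    """Split a batch into n_shards equal contiguous shards (one per device)."""
    if len(batch) % n_shards != 0:
        raise ValueError("batch size must be divisible by the number of shards")
    size = len(batch) // n_shards
    return [batch[i * size:(i + 1) * size] for i in range(n_shards)]


def data_parallel_step(w: float, batch: Sequence[Example], n_shards: int = 4,
                       lr: float = 0.1) -> float:
    """One SGD step: per-shard gradients, then average (all-reduce), then update."""
    shard_grads = [grad_mse(w, shard) for shard in split(batch, n_shards)]
    mean_grad = sum(shard_grads) / n_shards  # stand-in for an all-reduce
    return w - lr * mean_grad


def single_device_step(w: float, batch: Sequence[Example], lr: float = 0.1) -> float:
    """Reference: one SGD step on the whole batch."""
    return w - lr * grad_mse(w, batch)
```

With our small batches each shard is tiny, which is why the per-step synchronisation cost could outweigh the parallel speed-up.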
Stretch goals:
7. Also use preemptible TPUs (11x speed-up): Google TPU Research Cloud offers a further 100 free preemptible TPU v2s. (Dan clarifies that 'preemptible' means each VM can be killed at any point; I think they last up to 24 hours, after which I assume we can spawn new ones.)
   - We should streamline the process of creating VMs so that we can easily spin up new TPUs and integrate them into our system from step (5), allowing us to run up to 110 experiments in parallel.
   - We should make training more robust to stops and restarts, e.g. set it up so that we can continue training from the last checkpoint. There is already some code towards this, but it needs to be integrated into our system from step (5).

   These improvements will also be useful for experiments on non-preemptible VMs, which sometimes also need to be respawned, or have training resumed from a checkpoint after a crash.
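Below is a sketch of making a run resumable, assuming checkpoints are cheap enough to write every few steps. State is saved atomically (write to a temporary file, then `os.replace` it into place), so a VM killed mid-write never leaves a half-written checkpoint, and on startup the run resumes from the newest checkpoint if there is one. The file naming, the `pickle` format and the toy "training" update are hypothetical stand-ins for our real state and `torch.save`/`torch.load`.

```python
import os
import pickle
import re
import tempfile
from typing import Any, Dict, Optional

CKPT_PATTERN = re.compile(r"^step_(\d+)\.pkl$")  # hypothetical naming scheme


def save_atomic(state: Dict[str, Any], ckpt_dir: str) -> str:
    """Write state to a temp file in ckpt_dir, then atomically rename it into place."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step_{state['step']}.pkl")
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
        f.flush()
        os.fsync(f.fileno())  # make sure the bytes reach disk before the rename
    os.replace(tmp_path, path)  # atomic on POSIX within one filesystem
    return path


def latest_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the path of the checkpoint with the highest step, or None."""
    if not os.path.isdir(ckpt_dir):
        return None
    best = None
    for name in os.listdir(ckpt_dir):
        match = CKPT_PATTERN.match(name)  # skips leftover *.tmp files
        if match and (best is None or int(match.group(1)) > best[0]):
            best = (int(match.group(1)), name)
    return os.path.join(ckpt_dir, best[1]) if best else None


def train_resumable(ckpt_dir: str, total_steps: int, save_every: int = 10,
                    stop_at: Optional[int] = None) -> Dict[str, Any]:
    """Toy loop that resumes from the latest checkpoint if one exists.

    stop_at simulates preemption: the loop returns early without a final save.
    """
    path = latest_checkpoint(ckpt_dir)
    if path is not None:
        with open(path, "rb") as f:
            state = pickle.load(f)
    else:
        state = {"step": 0, "weight": 0.0}

    while state["step"] < total_steps:
        if stop_at is not None and state["step"] == stop_at:
            return state  # "preempted": progress since the last save is lost
        state["weight"] += 1.0  # stand-in for one optimisation step
        state["step"] += 1
        if state["step"] % save_every == 0 or state["step"] == total_steps:
            save_atomic(state, ckpt_dir)
    return state
```

The trade-off is in `save_every`: saving more often loses less work on preemption but spends more time writing checkpoints.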
8. More optimisation (uncertain, small speed-ups): Beyond just 'getting the TPU to run faster than the CPU' in step (4), there is potentially more room to speed up each training run:
   - We should explore performance using XLA metrics (and perhaps profiling) to see if there are any bottlenecks.
   - Some computations, e.g. model and dataset initialisation, batch generation and perhaps evaluation, may not actually be worth doing on the TPU, either because they are not much faster there or because they are only done once, so compilation doesn't pay off. We should identify these and isolate them from the part of the computation that is compiled for the TPU.
   - While the optimising compiler can improve our computational graphs, we can still potentially improve the graphs we give it before compilation. We previously did some light profiling of attention computation and causal masking methods on GPUs. This is worth revisiting for TPUs because the conclusions might be different.
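When revisiting those comparisons on TPUs, lazy execution makes naive timing misleading: the first calls include compilation, and without a demand point the timer can stop before anything has actually run. Below is a minimal timing helper that performs warm-up calls first and calls a `sync` callback before every clock read. On PyTorch/XLA the callback could be `xm.mark_step()` followed by `xm.wait_device_ops()`, an assumption worth confirming by checking that the `CompileTime` entry in `torch_xla.debug.metrics.metrics_report()` stops growing after warm-up. `benchmark` and its parameters are hypothetical.

```python
import statistics
import time
from typing import Callable, Dict, Optional


def benchmark(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 20,
    sync: Optional[Callable[[], None]] = None,
) -> Dict[str, float]:
    """Time fn after warm-up; sync() forces pending (lazy) work to finish."""
    sync = sync or (lambda: None)

    for _ in range(warmup):  # early calls may include tracing and compilation
        fn()
        sync()

    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # without this, lazy execution could make fn look almost free
        times.append(time.perf_counter() - start)

    return {"median_s": statistics.median(times), "min_s": min(times), "max_s": max(times)}
```

Comparing two causal-masking variants would then mean calling `benchmark` on each with identical inputs; reporting the median rather than the mean keeps one slow outlier (such as a late recompilation) from dominating the result.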