Around 8192 envs seems to be the limit. With the envs split (sharded) across 8 devices, it takes about 1 hour and 27 minutes. With more than 8192 envs, JIT AOT compilation runs into memory issues.
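Finding a limit like this by hand means trying one env count at a time, and each attempt costs a full compilation. The sketch below bisects over power-of-two env counts that divide evenly across the devices. It assumes feasibility is monotone (if N envs fit, any smaller count fits too) and relies on a caller-supplied predicate `fits`; the function name and parameters are illustrative, not part of PGX or JAX.

```python
from typing import Callable


def largest_feasible_num_envs(
    fits: Callable[[int], bool],
    num_devices: int = 8,
    max_exponent: int = 20,
) -> int:
    """Return the largest power-of-two env count for which `fits` is True.

    Assumes monotone feasibility: if `n` envs fit in memory, every smaller
    candidate does too. Only counts divisible by `num_devices` are tried,
    so each device receives an equal shard. Returns 0 if nothing fits.
    """
    # Candidate env counts, ascending: powers of two that split evenly.
    candidates = [
        2 ** k for k in range(max_exponent + 1) if (2 ** k) % num_devices == 0
    ]
    lo, hi = 0, len(candidates) - 1
    best = 0
    # Binary search needs only about log2(len(candidates)) compile attempts.
    while lo <= hi:
        mid = (lo + hi) // 2
        if fits(candidates[mid]):
            best = candidates[mid]  # feasible: try larger counts
            lo = mid + 1
        else:
            hi = mid - 1  # infeasible: try smaller counts
    return best
```

In a JAX setting, `fits` could build the batched step for `n` envs, call `jax.jit(step).lower(...).compile()` inside a `try` block, and return `False` if compilation fails. Which exception signals out-of-memory depends on the backend, so that part is an assumption to check on the target TPU.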
Following your comment, I ran some preliminary experiments and confirmed similar results.
I noticed that throughput on TPUs (v2) on Colab was indeed low, even though the same program scales as expected on a Colab A100 GPU.
This needs further investigation; let's track it in this issue.
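One way to make "scales as expected" concrete when comparing the GPU and TPU runs is to measure throughput at several env counts and compare the speedup with ideal linear scaling. The helper below is a hypothetical sketch; it takes measured steps/sec keyed by env count, and a value near 1.0 means throughput grew proportionally with the batch.

```python
def scaling_efficiency(throughputs: dict[int, float]) -> dict[int, float]:
    """Measured speedup divided by ideal linear speedup, per env count.

    `throughputs` maps env count -> measured steps/sec. The smallest env
    count is the baseline; 1.0 means perfectly linear scaling from it.
    """
    base_n = min(throughputs)
    base_t = throughputs[base_n]
    return {
        # (actual speedup) / (ideal speedup) relative to the baseline batch
        n: (t / base_t) / (n / base_n)
        for n, t in sorted(throughputs.items())
    }
```

Running the same measurement on both devices and comparing the two efficiency curves would separate "slow per step" from "does not benefit from larger batches".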
TODOs may include:
Try the latest TPU version
Test other games
Find the bottleneck (my guess is the legal action mask computation)
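A simple way to test the legal-action-mask hypothesis is an ablation: time the full step and a variant with the suspected component stubbed out, then attribute the difference to that component. The sketch assumes you can build such a variant yourself (a hypothetical `step_without_mask`; this thread does not show PGX offering such a switch). The stubbed step may produce illegal moves, so it is only meaningful for timing.

```python
import statistics
import time
from typing import Any, Callable


def median_seconds(
    fn: Callable[[], Any],
    repeats: int = 5,
    block: Callable[[Any], Any] = lambda x: x,
) -> float:
    """Median wall-clock time of `fn()` over `repeats` runs."""
    block(fn())  # warm-up call absorbs one-time costs such as compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        block(fn())  # wait for asynchronous results before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def ablation_share(
    full_step: Callable[[], Any],
    step_without_part: Callable[[], Any],
    repeats: int = 5,
    block: Callable[[Any], Any] = lambda x: x,
) -> float:
    """Estimated fraction of the full step's time spent in the removed part."""
    t_full = median_seconds(full_step, repeats, block)
    t_without = median_seconds(step_without_part, repeats, block)
    return (t_full - t_without) / t_full
```

With JAX, pass `jax.block_until_ready` as `block`, since dispatch is asynchronous and an unblocked call can return before the device has finished. A profiler trace gives a finer breakdown, but an ablation like this is a quick first check of one specific guess.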
sotetsuk changed the title from "PGX slower on TPUs" to "Low performance on TPUs" on Aug 31, 2024
PGX seems to run slower on TPUs than on CPUs.
On a TPU v3-8, PGX achieves only 1638 steps/sec on chess.
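Steps/sec figures are easy to distort when benchmarking JAX: the first call includes compilation, and dispatch is asynchronous, so the clock can stop before the device finishes. The harness below is a sketch of a fairer measurement. It assumes `step_fn` advances all `num_envs` environments by one step per call (so one call counts as `num_envs` steps); whether the reported figure uses the same convention is not stated here. The function and its parameters are illustrative.

```python
import time
from typing import Any, Callable


def measure_steps_per_sec(
    step_fn: Callable[[Any], Any],
    state: Any,
    num_envs: int,
    num_iters: int = 100,
    warmup_iters: int = 2,
    block: Callable[[Any], Any] = lambda x: x,
) -> float:
    """Environment steps per second for a batched step function.

    Each call to `step_fn` is assumed to step all `num_envs` envs once.
    """
    # Warm-up calls absorb one-time costs such as JIT compilation.
    for _ in range(warmup_iters):
        state = block(step_fn(state))
    start = time.perf_counter()
    for _ in range(num_iters):
        state = step_fn(state)
    # Each step depends on the previous state, so blocking on the final
    # state waits for all queued work before the clock stops.
    block(state)
    elapsed = time.perf_counter() - start
    return num_envs * num_iters / elapsed
```

For a JAX step, pass `block=jax.block_until_ready` and wrap the jitted, vmapped step so it takes only the state (for example, by choosing actions inside the wrapper).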
Minimal Reproducible Example
PGX CPU vs TPU Test (512 env) (with sharding)
PGX CPU vs TPU Test (64 env) (single device)