Description
Hi all, love the project and it being open-sourced!
I tried following along with the setup guide and noticed a bottleneck on my system that was significantly reduced using MPS with PyTorch.
System & Performance
Device - MacBook Pro (2022)
CPU - M2 8-core
GPU - 10 Core w/ Metal 3 support
Memory - 16 GB
Using the command for shape generation, the estimated time for extracting geometry was about two hours. After switching to torch.device("mps") in generate.py, that time dropped to about four minutes.
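For reference, the switch can be sketched roughly as below. This is a hypothetical snippet, not the actual generate.py code; the availability check guards machines without Metal support, and the ImportError branch is only there so the sketch runs without torch installed.

```python
# Sketch: prefer Apple's MPS backend when available, else fall back to CPU.
try:
    import torch
    use_mps = torch.backends.mps.is_available()
    device = torch.device("mps" if use_mps else "cpu")
except ImportError:
    device = "cpu"  # torch not installed; plain string placeholder
print(device)
```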
Problem
I would have made a pull request right off the bat with these changes, but it seems there is an operation on the KV caches of the attention layers that isn't implemented in MPS yet (index_copy_). It's not too much of a headache to work around: setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before running the scripts lets torch defer to the CPU for that operation.
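One code-based option, sketched below under the assumption that each script controls its own imports, is to set the variable from Python before torch is first imported, since the backend reads the flag at import time:

```python
import os

# Must be set before the first `import torch`; the MPS backend reads
# this flag at import time, so setting it later has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# import torch  # safe to import now; unsupported ops such as index_copy_
#               # fall back to the CPU implementation
```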
The only frustration is that this would need to be added to each script that could hit the operation, before torch is imported, or users would have to add the environment variable to their .rc files, which could be a pain to manage and keep track of. Since there is both a command-line and a code-based implementation, I thought I would create an issue first to figure out the best way forward.
(Side note: I did mention the operation on PyTorch's issue tracker, so feel free to give it a +1 to help get their attention.)