Metal Performance Shader (MPS) Integration #51

@ludsvick

Hi all, love the project and it being open-sourced!

I tried following along with the setup guide and noticed a bottleneck on my system that was significantly reduced using MPS with PyTorch.

System & Performance

Device - MacBook Pro (2022)
CPU - M2 8-core
GPU - 10 Core w/ Metal 3 support
Memory - 16 GB

Using the command for shape generation, the estimated time for extracting geometry was about two hours. After switching to torch.device("mps") in generate.py, that dropped to about four minutes.
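For anyone who wants to try the same switch without hard-coding the device, here is a minimal sketch of a device picker. The helper names (pick_device_name, detect_device_name) are hypothetical and not part of this repo; the only PyTorch calls used are torch.backends.mps.is_available() and torch.cuda.is_available(). The CUDA-before-MPS ordering is my assumption, not something the project prescribes.

```python
import importlib.util


def pick_device_name(mps_available: bool, cuda_available: bool) -> str:
    """Return the preferred device string: CUDA first, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device_name() -> str:
    """Ask PyTorch which backends are usable; fall back to CPU if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    # Older PyTorch builds have no torch.backends.mps at all, so guard the lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device_name(mps_ok, torch.cuda.is_available())
```

In generate.py this could be used as device = torch.device(detect_device_name()) in place of a fixed device string.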

Problem

I would have opened a pull request with these changes right away, but the KV caches in the transformer attention use an operation (index_copy_) that isn't implemented for MPS yet. It's not much of a headache to work around: setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before running the scripts lets torch fall back to the CPU for that operation.

The only frustration is that the variable would have to be set, before torch is imported, in every script that could hit the operation, or users would have to add it to their shell rc files, which could be a pain to manage and keep track of. Since there is both a command-line and a code-based way to run the project, I thought I would open an issue first to figure out the best way forward.
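One possible way to avoid repeating this in every script, sketched under assumptions: a tiny helper that each entry point calls before its first torch import. The function name enable_mps_fallback is hypothetical, and where it would live in the repo is up to the maintainers. It uses setdefault so a value the user has already exported is left alone.

```python
import os
import sys


def enable_mps_fallback() -> bool:
    """Set PYTORCH_ENABLE_MPS_FALLBACK=1 unless the user already chose a value.

    Returns True if torch has not been imported yet, i.e. the call happened
    early enough to satisfy the "before torch is imported" requirement.
    """
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
    return "torch" not in sys.modules


# Call this at the very top of an entry point, above any torch import.
called_early = enable_mps_fallback()
```

A script could check called_early and print a warning when it is False, since in that case torch was already loaded when the variable was set.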

(Side note: I did mention the operation in PyTorch's tracker, so feel free to give it a +1 to help get their attention)
