@rchen152 rchen152 commented Dec 10, 2025

Enables static type checking of torchtitan with pyrefly. Type checking the code helps catch bugs earlier in the development cycle.

  • Adds pyrefly to CI, as part of the linting workflow.
  • Addresses ~100 type errors that can be fixed via local code changes and updates to type annotations, and silences the rest with # pyrefly: ignore suppression comments. Note that 325efd9 contains all of the non-comment changes.

@meta-cla meta-cla bot added the CLA Signed label Dec 10, 2025
@rchen152 rchen152 marked this pull request as ready for review December 10, 2025 10:07
Contributor

@tianyu-l tianyu-l left a comment

Thanks, had some comments.

Comment on lines 10 to 16
uv pip install \
torchft-nightly \
-r .ci/docker/requirements.txt \
-r .ci/docker/requirements-dev.txt \
-r .ci/docker/requirements-flux.txt \
-r .ci/docker/requirements-transformers-modeling-backend.txt \
-r .ci/docker/requirements-vlm.txt \
Contributor

We probably shouldn't mix the repo dependency installs with the type checking install. Maybe we should put the pyrefly install into requirements-dev.txt.

Author

Done. FYI, the reason I originally made this a stand-alone script was so that I could install nightly torch and torchao from https://download.pytorch.org/whl/nightly/cu128 in a separate step. With all the dependencies in requirements-dev.txt, we're now installing those deps from PyPI, which changes the type checking results a bit, so I had to add some new suppressions.

Author

I ended up moving torch and torchao out again to a separate install step; all the other deps are in requirements-dev.txt.

CONTRIBUTING.md Outdated
6. If you haven't already, complete the Contributor License Agreement ("CLA").
6. Make sure your code type checks:
```
.github/scripts/setup_pyrefly.sh # one-time setup
Contributor

Similarly, this should follow the pre-commit style.

Author

Done, pyrefly is now run as a pre-commit hook.

Just FYI, the pyrefly hook works a little differently from the others - you have to install all the dependencies yourself (including pyrefly) before running it. This is because a type checker needs to be able to see a project's dependencies, so we use language: system to use the local environment rather than an isolated environment managed by pre-commit.

Contributor

@fegin fegin left a comment

This is very helpful. This pyrefly PR already catches at least one bug. We should definitely enforce stricter type checking.


with torch.device("cpu"):
    model = train_spec.model_cls(model_args)
# pyrefly: ignore [bad-argument-type]
Contributor

TODO for me: Change ModelProtocol to include nn.Module as the base class.

Author

@fegin Would you like me to put these TODOs in the code?
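
One plausible reading of the TODO above, sketched with stand-in classes rather than the real torchtitan or PyTorch types: if a factory is annotated as returning a structural protocol that does not promise the nn.Module API, a type checker would be expected to reject passing the result to code that requires a module. Here `Module`, `BaseModel`, `Toy`, `build`, `count_params`, and the `init_weights` method are all hypothetical names chosen for illustration.

```python
from typing import Protocol


class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def parameters(self) -> list[float]:
        return []


class ModelProtocol(Protocol):
    """Structural type declaring only a model-specific method (assumed)."""

    def init_weights(self) -> None: ...


class BaseModel(Module):
    """Nominal alternative: anything typed as BaseModel is also a Module."""

    def init_weights(self) -> None:
        raise NotImplementedError


class Toy(BaseModel):
    def init_weights(self) -> None:
        pass


def count_params(module: Module) -> int:
    # Requires the Module API, which ModelProtocol does not declare.
    return len(module.parameters())


def build(model_cls: type[ModelProtocol]) -> ModelProtocol:
    return model_cls()


model = build(Toy)
# Statically, `model` is only known to be a ModelProtocol, so a checker
# would likely flag this call; at runtime it works because Toy is a Module.
n_params = count_params(model)  # pyrefly: ignore [bad-argument-type]
```

Annotating the factory with the nominal `BaseModel` type instead would let the call type check without a suppression, which is roughly what making nn.Module part of the base type would achieve.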


# Calculate KL Divergence
- kl_loss = F.kl_div(probs1, probs2, "mean")
+ kl_loss = F.kl_div(probs1, probs2, reduction="mean")
Contributor

This is interesting, so what's the error pyrefly reports without this change?

Author

This is the error:

ERROR Argument `Literal['mean']` is not assignable to parameter `size_average` with type `bool | None` in function `torch.nn.functional.kl_div` [bad-argument-type]
  --> scripts/checkpoint_conversion/numerical_tests_example.py:28:40
   |
28 |     kl_loss = F.kl_div(probs1, probs2, "mean")
   |                                        ^^^^^^
   |
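
The error arises because the third positional parameter of F.kl_div is the legacy size_average flag, not reduction. A minimal sketch, using a stand-in function that copies the parameter order from the PyTorch documentation (the name `kl_div_stub` is hypothetical and the function does no computation), shows where a positional "mean" lands:

```python
import inspect


def kl_div_stub(input, target, size_average=None, reduce=None,
                reduction="mean", log_target=False):
    """Hypothetical stand-in mirroring the parameter order of
    torch.nn.functional.kl_div; it performs no computation."""


sig = inspect.signature(kl_div_stub)

# The third positional argument binds to `size_average`, not `reduction`.
positional = sig.bind("probs1", "probs2", "mean")
print(dict(positional.arguments))

# Passing the value by keyword binds it to the intended parameter.
keyword = sig.bind("probs1", "probs2", reduction="mean")
print(dict(keyword.arguments))
```

Since size_average is typed as `bool | None`, the string literal is rejected statically, which matches the report above.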

Comment on lines +119 to +120
optimizers.optimizers,
job_config.lr_scheduler,
Contributor

TODO for me: build_lr_schedulers has incorrect typing.

state_dict,
storage_writer=storage_writer,
checkpoint_id=checkpoint_save_id,
# pyrefly: ignore [bad-argument-type]
Contributor

TODO for me: figure out why self.pg has incorrect typing, or why async_save doesn't accept the PG type.

...


# pyrefly: ignore [inconsistent-inheritance]
Contributor

TODO for me: this seems to be an inconsistency issue between checkpoint Stateful and dataloader StatefulDataLoader, which we may not be able to change.

from typing import Any, Generic, Iterator, TypeVar

import torch
import torch.distributed.tensor
Contributor

TODO for me: change references to torch.distributed.tensor.DTensor to DTensor.


# pyrefly: ignore [bad-instantiation]
schedule = schedule_class(
# pyrefly: ignore [bad-argument-type]
Contributor

Typing issue with get_schedule_class; cc @wconstab @H-Huang

"Context Parallel API. Please update to a newer version."
)
) from e

Contributor

TODO for me: set_rotate_method has been in the main PyTorch for a while. We should be able to remove this try/except.

pre-commit
pyrefly==0.45.1
tomli-w >= 1.1.0
torch
Contributor

We cannot install PyTorch with a plain pip install, because torchtitan depends on the PyTorch nightly build: https://github.com/pytorch/torchtitan/blob/refs/heads/main/README.md#nightly-builds

Author

@wwwjn I originally had a separate setup_pyrefly.sh script that installed nightly torch and torchao, which I removed based on @tianyu-l's comment that it would be preferable to have all of the dependencies pyrefly needs in requirements-dev.txt. Do you have any suggestions for how to resolve this?

Contributor

Thank you! I think the other dependencies are fine, but for torch and torchao, can we move these two packages to lint.yaml and install them separately, using a command like pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall?

Author

Done, thanks!

# pyrefly: ignore [bad-argument-type]
model = ModelWrapper(model)

# pyrefly: ignore [not-callable]
Contributor

@wwwjn wwwjn Dec 10, 2025

pyrefly: ignore [bad-argument-type]

I feel like we need to fix or refactor the code if it doesn't pass the pyrefly check, instead of leaving a lot of comments to suppress the errors.

Author

@wwwjn Adding ignores to get the code to a clean state and then incrementally fixing them is the standard way to enable a type checker on a large code base. I already made fixes to remove 100+ pyrefly: ignores, but the remaining issues look fairly tricky to resolve and IMO would be better tackled as follow-ups rather than trying to combine larger and more risky refactors with enabling the type checker.
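
As a minimal sketch of that workflow (the `greet` function and both call sites are hypothetical, not code from this PR): the first step suppresses a known error so the checker passes on the existing code, and a later follow-up fixes the call and deletes the suppression.

```python
def greet(name: str) -> str:
    return f"hello {name}"


# Step 1: suppress the known error so the checker is clean on existing code.
# pyrefly: ignore [bad-argument-type]
legacy = greet(42)

# Step 2 (a later follow-up): fix the call site and remove the suppression.
fixed = greet(str(42))
```

Each suppression names a specific error code, so any new, unrelated error on the same line would presumably still be reported.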

@rchen152 rchen152 requested a review from tianyu-l December 10, 2025 20:35
@rchen152
Author

Thanks for all the comments! I believe this is ready for another round of review.

pyrefly==0.45.1
tomli-w >= 1.1.0
torchdata >= 0.8.0
torchft-nightly
Contributor

For now we can handle this the same way as torchao.
@fegin we should move torchft to experiments.

run: |
  python -m pip install pre-commit
  python -m pip install -r requirements-dev.txt
  python -m pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch torchao
Contributor

Can users not run pyrefly / pre-commit without installing the optional torchao package?

Author

I updated the pyrefly config so that torchao is no longer required - pyrefly now uses ignore-missing-imports for torchao, so it'll replace the import with Any rather than erroring when it's not present.
