3 changes: 3 additions & 0 deletions .ci/docker/requirements-dev.txt
@@ -1,6 +1,9 @@
einops
Contributor: @wwwjn is this and pillow for FLUX? Maybe we should put them in core requirements.txt.

expecttest==0.1.6
pillow
pytest==7.3.2
pytest-cov
pre-commit
pyrefly==0.45.1
tomli-w >= 1.1.0
transformers
3 changes: 2 additions & 1 deletion .github/workflows/lint.yaml
@@ -28,7 +28,8 @@ jobs:
run: python -m pip install --upgrade pip
- name: Install lint utilities
run: |
python -m pip install pre-commit
python -m pip install -r requirements.txt -r requirements-dev.txt
python -m pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch
pre-commit install-hooks
- name: Get changed files
id: changed-files
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -61,3 +61,11 @@ repos:
types: [text]
additional_dependencies:
- tomli

- repo: https://github.com/facebook/pyrefly-pre-commit
rev: 0.45.1
hooks:
- id: pyrefly-check
name: Pyrefly (type checking)
pass_filenames: false
language: system
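
Since the hook is declared with `language: system`, pyrefly itself must already be installed in the environment (it is pinned in requirements-dev.txt above). To exercise the new hook locally without making a commit, something like the following should work (a sketch, assuming the standard pre-commit CLI):

```
# run only the pyrefly hook against the whole repo
pre-commit run pyrefly-check --all-files
```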
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
@@ -4,7 +4,8 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin

### Setup
```
pip install -r requirements-dev.txt
pip install -r requirements.txt -r requirements-dev.txt
pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch
```

Contributor: Let's remove this for two reasons:

1. It's not the only way to install torch.
2. It's already covered by the install tips for using torchtitan in the README.

We can add a link to that section if you feel it's necessary.

### Pull Requests
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -62,3 +62,7 @@ include = ["torchtitan*"]
[tool.pytest.ini_options]
addopts = ["--showlocals"] # show local variables in tracebacks
testpaths = ["tests"]

[tool.pyrefly]
Contributor: Can I ask why we'd put this in pyproject?

project-excludes = ["torchtitan/experiments", "**/tests/**"]
ignore-missing-imports = ["torchao.*", "torchft"] # optional dependencies
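
Since the configuration lives in pyproject.toml, the checker picks it up automatically when run from the repo root; a minimal local-usage sketch, assuming pyrefly's standard CLI:

```
# reads [tool.pyrefly] from pyproject.toml, honoring project-excludes
pyrefly check
```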
4 changes: 2 additions & 2 deletions scripts/checkpoint_conversion/convert_from_hf.py
@@ -16,16 +16,16 @@

@torch.inference_mode()
def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
if model_name == "flux":
import torchtitan.experiments.flux # noqa: F401
# initialize model to allocate memory for state dict
train_spec = train_spec_module.get_train_spec(model_name)
model_args = train_spec.model_args[model_flavor]

with torch.device("cpu"):
model = train_spec.model_cls(model_args)
# pyrefly: ignore [bad-argument-type]
model = ModelWrapper(model)

# pyrefly: ignore [not-callable]
Contributor (wwwjn, Dec 10, 2025): Regarding `# pyrefly: ignore [bad-argument-type]`: I feel like we need to fix or refactor the code if it doesn't pass the pyrefly check, instead of leaving a lot of comments to suppress the error.

Author: @wwwjn Adding ignores to get the code to a clean state and then incrementally fixing them is the standard way to enable a type checker on a large codebase. I already made fixes that removed 100+ `pyrefly: ignore`s, but the remaining issues look fairly tricky to resolve and IMO are better tackled as follow-ups than combined with larger, riskier refactors while enabling the type checker.

Contributor: @wwwjn It's actually our responsibility to fix the typing, since we understand the code. We should have follow-up BE PRs to address these ignores.

Contributor: Agreed, let's do some BE work later.

sd_adapter = train_spec.state_dict_adapter(model_args, None)
assert (
sd_adapter is not None
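
As an aside on the burn-down discussed above, the remaining suppressions could be audited with a one-liner (a sketch, assuming GNU grep in a POSIX shell):

```
# count remaining suppressions to address in follow-up PRs
grep -rn "pyrefly: ignore" --include="*.py" . | wc -l
```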
2 changes: 2 additions & 0 deletions scripts/checkpoint_conversion/convert_to_hf.py
@@ -30,8 +30,10 @@ def convert_to_hf(

with torch.device("cpu"):
model = train_spec.model_cls(model_args)
# pyrefly: ignore [bad-argument-type]
Contributor: TODO for me: change ModelProtocol to include nn.Module as the base class.

Author: @fegin Would you like me to put these TODOs in the code?

Contributor: No need, we should go through these ignores later one by one anyway. Thanks!

model = ModelWrapper(model)

# pyrefly: ignore [not-callable]
sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path)
assert (
sd_adapter is not None
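
A minimal sketch of the refactor floated in the thread above: if ModelProtocol (or a base class replacing it) inherits from nn.Module, models passed to ModelWrapper type-check as nn.Module and the bad-argument-type suppressions can be dropped. The names and method signatures below are illustrative assumptions, not the actual torchtitan definitions:

```python
import torch
import torch.nn as nn


class ModelProtocol(nn.Module):
    """Hypothetical base class: inheriting from nn.Module means any model
    accepted here also satisfies APIs typed against nn.Module, such as
    ModelWrapper in the diffs above."""

    def init_weights(self, buffer_device: torch.device | None = None) -> None:
        raise NotImplementedError
```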
10 changes: 8 additions & 2 deletions scripts/checkpoint_conversion/numerical_tests_example.py
@@ -25,7 +25,7 @@ def loss_fn(logits1, logits2):
probs2 = F.softmax(logits2, dim=-1)

# Calculate KL Divergence
kl_loss = F.kl_div(probs1, probs2, "mean")
kl_loss = F.kl_div(probs1, probs2, reduction="mean")
Contributor: This is interesting; what's the error pyrefly reports without this change?

Author: This is the error:

```
ERROR Argument `Literal['mean']` is not assignable to parameter `size_average` with type `bool | None` in function `torch.nn.functional.kl_div` [bad-argument-type]
  --> scripts/checkpoint_conversion/numerical_tests_example.py:28:40
   |
28 |     kl_loss = F.kl_div(probs1, probs2, "mean")
   |                                        ^^^^^^
   |
```

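For context: F.kl_div's signature places the deprecated size_average and reduce parameters before reduction, so a positional "mean" binds to size_average, which is exactly what the error says. A minimal sketch of the keyword form, assuming the current torch.nn.functional signature:

```python
import torch
import torch.nn.functional as F

# Signature (abridged): kl_div(input, target, size_average=None,
#                              reduce=None, reduction="mean", log_target=False)
log_p = torch.randn(2, 5).log_softmax(dim=-1)  # kl_div expects log-probabilities
q = torch.randn(2, 5).softmax(dim=-1)
loss = F.kl_div(log_p, q, reduction="mean")  # keyword form satisfies the checker
```
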
return kl_loss


@@ -75,10 +75,13 @@ def forward_tt(config_path, checkpoint_path, test_set):

# materialize model
device = torch.device(device_type)
# pyrefly: ignore [missing-attribute]
model.to_empty(device=device)
model.init_weights(buffer_device=device)
# pyrefly: ignore [missing-attribute]
model.eval()

# pyrefly: ignore [bad-argument-type]
modelWrapper = ModelWrapper(model)
state_dict = modelWrapper._get_state_dict()

Expand All @@ -94,6 +97,7 @@ def forward_tt(config_path, checkpoint_path, test_set):
input_ids = input_ids.unsqueeze(0)

# obtains the logits of only the last token in the predictions
# pyrefly: ignore [not-callable]
predictions = model(input_ids)[:, -1, :].unsqueeze(1)
output_list.append(predictions)

@@ -120,6 +124,7 @@ def forward_tt(config_path, checkpoint_path, test_set):
config_manager = ConfigManager()
config = config_manager.parse_args([f"--job.config_file={config_path}"])
train_spec = get_train_spec(config.model.name)
# pyrefly: ignore [not-callable]
tokenizer = train_spec.build_tokenizer_fn(config)

# Build test set of randomly generated token ids
@@ -150,10 +155,11 @@ def forward_tt(config_path, checkpoint_path, test_set):
avg_losses = {}

for test_name, (baseline_outputs, conversion_outputs) in test_configs.items():
total_loss = 0
total_loss: int | torch.Tensor = 0
for baseline, outputs in zip(baseline_outputs, conversion_outputs):
total_loss += loss_fn(baseline, outputs)
avg_loss = total_loss / len(test_set)
# pyrefly: ignore [missing-attribute]
avg_losses[test_name] = avg_loss.item()

for test_name, avg_loss in avg_losses.items():
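
One note on the total_loss annotation above: the accumulator starts as the int 0 and becomes a Tensor after the first +=, so the widened annotation reflects both states; a minimal illustration:

```python
import torch

total: int | torch.Tensor = 0  # starts life as an int
for t in (torch.tensor(1.0), torch.tensor(2.0)):
    total += t  # int + Tensor -> Tensor on the first iteration
assert isinstance(total, torch.Tensor)
```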
1 change: 1 addition & 0 deletions scripts/download_hf_assets.py
@@ -167,6 +167,7 @@ def should_download(patterns: list[str], filename: str) -> bool:
missed_files = []

# Download files with progress bar
# pyrefly: ignore [bad-context-manager]
with tqdm(total=len(files_found), desc="Downloading files", unit="file") as pbar:
for filename in files_found:
try:
17 changes: 15 additions & 2 deletions scripts/estimate/estimation.py
@@ -98,25 +98,33 @@ def estimate_memory(job_config: JobConfig):

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
# pyrefly: ignore [bad-argument-type]
model_converters.convert(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
train_spec.parallelize_fn(model, parallel_dims, job_config)

# pyrefly: ignore [missing-attribute]
model.to_empty(device="cuda")
if not active_fake_mode():
model.init_weights()
# pyrefly: ignore [missing-attribute]
model.train()

# build optimizer after applying parallelisms to the model
# pyrefly: ignore [bad-argument-type]
optimizers = build_optimizers([model], job_config.optimizer, parallel_dims)
lr_schedulers = build_lr_schedulers(
optimizers.optimizers, job_config.lr_scheduler, job_config.training.steps
# pyrefly: ignore [bad-argument-type]
optimizers.optimizers,
job_config.lr_scheduler,
Contributor (on lines +119 to +120): TODO for me: build_lr_schedulers has incorrect typing.

job_config.training.steps,
)
# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
# where it issues a single all-reduce for all parameters at once for better performance
optimizers.register_step_post_hook(
# pyrefly: ignore [bad-argument-type]
lambda *args, **kwargs: model_converters.post_optimizer_hook(model)
)

@@ -136,6 +144,7 @@ def estimate_memory(job_config: JobConfig):
device="cuda",
),
)
# pyrefly: ignore [bad-argument-type]
fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
fsdp_memtracker.track_inputs(batch)

Expand All @@ -145,14 +154,18 @@ def estimate_memory(job_config: JobConfig):
input_ids, labels = batch
# train step
with train_context():
# pyrefly: ignore [not-callable]
pred = model(input_ids)
loss = loss_fn(pred, labels)
del pred
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
# pyrefly: ignore [missing-attribute]
model.parameters(),
job_config.training.max_norm,
foreach=True,
)
# optimizer step
optimizers.step()
9 changes: 9 additions & 0 deletions scripts/generate/test_generate.py
@@ -36,6 +36,7 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

# pyrefly: ignore [missing-import]
from generate._generation import generate


@@ -49,6 +50,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh):
},
)

# pyrefly: ignore [missing-attribute]
for _, transformer_block in model.layers.items():
layer_plan = {
"attention.wq": ColwiseParallel(),
@@ -63,6 +65,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh):
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
# pyrefly: ignore [bad-argument-type]
parallelize_plan=layer_plan,
)

@@ -95,6 +98,7 @@ def test_generate(
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device(f"{device_type}:{local_rank}")
# pyrefly: ignore [missing-attribute]
device_module.set_device(device)
device_memory_monitor = build_device_memory_monitor()

@@ -103,6 +107,7 @@
logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")

# Tokenizer setup
# pyrefly: ignore [not-callable]
tokenizer = train_spec.build_tokenizer_fn(config)

model_args = train_spec.model_args[config.model.flavor]
@@ -131,6 +136,7 @@ def test_generate(

# apply_tp (with Sequence Parallel) on unevenly sharded
# sequences would require https://github.com/pytorch/torchtitan/pull/686
# pyrefly: ignore [bad-argument-type]
apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])

debug_config = DebugConfig(seed=seed, deterministic=deterministic)
@@ -142,11 +148,14 @@
)

# materialize model
# pyrefly: ignore [missing-attribute]
model.to_empty(device=device_type)
with torch.no_grad():
model.init_weights()
# pyrefly: ignore [missing-attribute]
model.eval()

# pyrefly: ignore [missing-attribute]
state_dict = model.state_dict()

# Checkpoint Loading
1 change: 1 addition & 0 deletions scripts/loss_compare.py
@@ -134,6 +134,7 @@ def run_with_realtime_output(cmd: str, logfile: str, env: dict[str, Any]) -> Non
bufsize=1,
)

# pyrefly: ignore [not-iterable]
for line in process.stdout:
print(line, end="")
log_f.write(line)