# Enable static type checking with Pyrefly #2136
**`requirements-dev.txt`**

```diff
@@ -1,6 +1,9 @@
+einops
 expecttest==0.1.6
+pillow
 pytest==7.3.2
 pytest-cov
 pre-commit
+pyrefly==0.45.1
 tomli-w >= 1.1.0
 transformers
```

> **Contributor:** @wwwjn is this and pillow for FLUX? Maybe we should put them in core requirements.txt.
**`CONTRIBUTING.md`**

````diff
@@ -4,7 +4,8 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin
 ### Setup
 ```
-pip install -r requirements-dev.txt
+pip install -r requirements.txt -r requirements-dev.txt
+pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch
 ```

 ### Pull Requests
````

> **Contributor:** Let's remove this for two reasons: […] We can add a link to that section if you feel necessary.
**`pyproject.toml`**

```diff
@@ -62,3 +62,7 @@ include = ["torchtitan*"]
 [tool.pytest.ini_options]
 addopts = ["--showlocals"]  # show local variables in tracebacks
 testpaths = ["tests"]
+
+[tool.pyrefly]
+project-excludes = ["torchtitan/experiments", "**/tests/**"]
+ignore-missing-imports = ["torchao.*", "torchft"]  # optional dependencies
```

> **Contributor** (on `[tool.pyrefly]`): Can I ask why we'd put this in pyproject?
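For context (based on pyrefly's documentation, not stated in the PR itself): pyrefly reads its configuration from either a standalone `pyrefly.toml` or a `[tool.pyrefly]` table in `pyproject.toml`, so keeping it in `pyproject.toml` avoids adding another top-level config file. A hypothetical standalone equivalent drops the table prefix:

```toml
# pyrefly.toml -- hypothetical standalone equivalent of the
# [tool.pyrefly] table added in this PR
project-excludes = ["torchtitan/experiments", "**/tests/**"]
ignore-missing-imports = ["torchao.*", "torchft"]
```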
**Checkpoint conversion (`convert_from_hf`)**

```diff
@@ -16,16 +16,16 @@
 @torch.inference_mode()
 def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
     if model_name == "flux":
         import torchtitan.experiments.flux  # noqa: F401
     # initialize model to allocate memory for state dict
     train_spec = train_spec_module.get_train_spec(model_name)
     model_args = train_spec.model_args[model_flavor]

     with torch.device("cpu"):
         model = train_spec.model_cls(model_args)
+    # pyrefly: ignore [bad-argument-type]
     model = ModelWrapper(model)

+    # pyrefly: ignore [not-callable]
     sd_adapter = train_spec.state_dict_adapter(model_args, None)
     assert (
         sd_adapter is not None
```

> **Contributor:** I feel like we need to fix/refactor the code if it doesn't pass the pyrefly check, instead of leaving a lot of comments to suppress the error.
>
> **Author:** @wwwjn Adding ignores to get the code to a clean state and then incrementally fixing them is the standard way to enable a type checker on a large code base. I already made fixes to remove 100+ […]
>
> **Contributor:** @wwwjn It's actually our responsibility to fix the typing as we understand the code. We should have follow-up BE PRs to address these ignores.
>
> **Contributor:** Agreed, let's do some BE work later.
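As a sketch of how a `not-callable` suppression like the one on `state_dict_adapter` could eventually be retired (hypothetical names, not the repo's actual types): when a field is typed `Optional[Callable]`, narrowing it with an assertion lets the checker verify the call site without an ignore comment.

```python
from typing import Callable, Optional

# Hypothetical stand-in for an optional factory field such as
# train_spec.state_dict_adapter, which a checker flags as possibly None.
state_dict_adapter: Optional[Callable[[str], str]] = str.upper

# Narrow the Optional before calling; the assert both documents the
# invariant and removes the need for `# pyrefly: ignore [not-callable]`.
assert state_dict_adapter is not None
result = state_dict_adapter("llama3")
print(result)  # LLAMA3
```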
**Checkpoint conversion (`convert_to_hf`)**

```diff
@@ -30,8 +30,10 @@ def convert_to_hf(
     with torch.device("cpu"):
         model = train_spec.model_cls(model_args)
+    # pyrefly: ignore [bad-argument-type]
     model = ModelWrapper(model)

+    # pyrefly: ignore [not-callable]
     sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path)
     assert (
         sd_adapter is not None
```

> **Contributor:** TODO for me: change ModelProtocol to include nn.Module as the base class.
>
> **Author:** @fegin Would you like me to put these TODOs in the code?
>
> **Contributor:** No need, we should go through these ignores later one by one anyway, thanks!
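One wrinkle with the TODO above: a `typing.Protocol` cannot list a concrete class like `nn.Module` among its bases, so "include nn.Module as the base class" likely means converting `ModelProtocol` into an abstract base class. A torch-free sketch of that pattern, with hypothetical names standing in for the repo's types:

```python
from abc import ABC, abstractmethod


class Module:
    """Stand-in for torch.nn.Module, providing framework methods."""

    def train(self) -> "Module":
        self.training = True
        return self


class BaseModel(Module, ABC):
    """An ABC inheriting the framework base class gives implementations
    both the required interface (init_weights) and the inherited
    Module-style attributes (train, eval, ...), which is what the
    missing-attribute ignores in this PR work around."""

    @abstractmethod
    def init_weights(self) -> None: ...


class ToyModel(BaseModel):
    def init_weights(self) -> None:
        self.initialized = True


m = ToyModel()
m.init_weights()
m.train()
print(m.training, m.initialized)  # True True
```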
**Numerical comparison test (`loss_fn` / `forward_tt`)**

```diff
@@ -25,7 +25,7 @@ def loss_fn(logits1, logits2):
     probs2 = F.softmax(logits2, dim=-1)

     # Calculate KL Divergence
-    kl_loss = F.kl_div(probs1, probs2, "mean")
+    kl_loss = F.kl_div(probs1, probs2, reduction="mean")
     return kl_loss
```

> **Contributor:** This is interesting, so what's the error pyrefly reports without this change?
>
> **Author:** This is the error: […]
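The likely reason pyrefly flags the original call (an inference from `F.kl_div`'s signature, since the reported error text is not shown above): the third positional parameter of `F.kl_div` is the deprecated `size_average: Optional[bool]`, so `"mean"` must be passed via the `reduction` keyword. A minimal sketch:

```python
import torch
import torch.nn.functional as F

logits1 = torch.randn(4, 8)
logits2 = torch.randn(4, 8)

# F.kl_div(input, target, size_average=None, reduce=None, reduction="mean", ...)
# Passing "mean" positionally binds it to size_average (Optional[bool]),
# a str-for-bool type error; reduction must be given as a keyword here.
# Note also that kl_div expects log-probabilities as its first argument.
log_probs1 = F.log_softmax(logits1, dim=-1)
probs2 = F.softmax(logits2, dim=-1)

kl_loss = F.kl_div(log_probs1, probs2, reduction="mean")
print(kl_loss.ndim)  # 0: "mean" reduces to a scalar
```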
```diff
@@ -75,10 +75,13 @@ def forward_tt(config_path, checkpoint_path, test_set):
     # materialize model
     device = torch.device(device_type)
+    # pyrefly: ignore [missing-attribute]
     model.to_empty(device=device)
     model.init_weights(buffer_device=device)
+    # pyrefly: ignore [missing-attribute]
     model.eval()

+    # pyrefly: ignore [bad-argument-type]
     modelWrapper = ModelWrapper(model)
     state_dict = modelWrapper._get_state_dict()
```
```diff
@@ -94,6 +97,7 @@ def forward_tt(config_path, checkpoint_path, test_set):
     input_ids = input_ids.unsqueeze(0)

     # obtains the logits of only the last token in the predictions
+    # pyrefly: ignore [not-callable]
     predictions = model(input_ids)[:, -1, :].unsqueeze(1)
     output_list.append(predictions)
```
```diff
@@ -120,6 +124,7 @@ def forward_tt(config_path, checkpoint_path, test_set):
     config_manager = ConfigManager()
     config = config_manager.parse_args([f"--job.config_file={config_path}"])
     train_spec = get_train_spec(config.model.name)
+    # pyrefly: ignore [not-callable]
     tokenizer = train_spec.build_tokenizer_fn(config)

     # Build test set of randomly generated token ids
```
```diff
@@ -150,10 +155,11 @@ def forward_tt(config_path, checkpoint_path, test_set):
     avg_losses = {}

     for test_name, (baseline_outputs, conversion_outputs) in test_configs.items():
-        total_loss = 0
+        total_loss: int | torch.Tensor = 0
         for baseline, outputs in zip(baseline_outputs, conversion_outputs):
             total_loss += loss_fn(baseline, outputs)
         avg_loss = total_loss / len(test_set)
+        # pyrefly: ignore [missing-attribute]
         avg_losses[test_name] = avg_loss.item()

     for test_name, avg_loss in avg_losses.items():
```
**Memory estimation (`estimate_memory`)**

```diff
@@ -98,25 +98,33 @@ def estimate_memory(job_config: JobConfig):
     # Build the collection of model converters. No-op if `model.converters` empty
     model_converters = build_model_converters(job_config, parallel_dims)
+    # pyrefly: ignore [bad-argument-type]
     model_converters.convert(model)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     train_spec.parallelize_fn(model, parallel_dims, job_config)

+    # pyrefly: ignore [missing-attribute]
     model.to_empty(device="cuda")
     if not active_fake_mode():
         model.init_weights()
+    # pyrefly: ignore [missing-attribute]
     model.train()

     # build optimizer after applying parallelisms to the model
+    # pyrefly: ignore [bad-argument-type]
     optimizers = build_optimizers([model], job_config.optimizer, parallel_dims)
     lr_schedulers = build_lr_schedulers(
-        optimizers.optimizers, job_config.lr_scheduler, job_config.training.steps
+        # pyrefly: ignore [bad-argument-type]
+        optimizers.optimizers,
+        job_config.lr_scheduler,
+        job_config.training.steps,
     )
     # Post optimizer step model converters hook.
     # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
     # where it issues a single all-reduce for all parameters at once for better performance
     optimizers.register_step_post_hook(
+        # pyrefly: ignore [bad-argument-type]
         lambda *args, **kwargs: model_converters.post_optimizer_hook(model)
     )
```

> **Contributor** (on lines +119 to +120): TODO for me: […]
```diff
@@ -136,6 +144,7 @@ def estimate_memory(job_config: JobConfig):
             device="cuda",
         ),
     )
+    # pyrefly: ignore [bad-argument-type]
     fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
     fsdp_memtracker.track_inputs(batch)
```
```diff
@@ -145,14 +154,18 @@ def estimate_memory(job_config: JobConfig):
     input_ids, labels = batch
     # train step
     with train_context():
+        # pyrefly: ignore [not-callable]
         pred = model(input_ids)
         loss = loss_fn(pred, labels)
         del pred
         loss.backward()

     # clip gradients
     torch.nn.utils.clip_grad_norm_(
-        model.parameters(), job_config.training.max_norm, foreach=True
+        # pyrefly: ignore [missing-attribute]
+        model.parameters(),
+        job_config.training.max_norm,
+        foreach=True,
     )
     # optimizer step
     optimizers.step()
```
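One way some of the remaining `missing-attribute` and `bad-argument-type` ignores could later be retired (a sketch under assumed types, not the repo's stated plan): narrow the model to `nn.Module` once through a runtime-checked helper, instead of suppressing each attribute access separately.

```python
from typing import cast

import torch.nn as nn


def as_module(model: object) -> nn.Module:
    # The runtime check backs up the cast, so a wrong assumption about
    # the model's type fails loudly instead of silently type-checking.
    assert isinstance(model, nn.Module)
    return cast(nn.Module, model)


m = as_module(nn.Linear(4, 2))
m.train()          # attribute access now type-checks without ignores
print(m.training)  # True
```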