diff --git a/.ci/docker/requirements-dev.txt b/.ci/docker/requirements-dev.txt index 6d53b2f817..0e5a6e491c 100644 --- a/.ci/docker/requirements-dev.txt +++ b/.ci/docker/requirements-dev.txt @@ -2,5 +2,6 @@ expecttest==0.1.6 pytest==7.3.2 pytest-cov pre-commit +pyrefly==0.45.1 tomli-w >= 1.1.0 transformers diff --git a/.ci/docker/requirements-flux.txt b/.ci/docker/requirements-flux.txt index daefd67ff0..8d6797a36b 100644 --- a/.ci/docker/requirements-flux.txt +++ b/.ci/docker/requirements-flux.txt @@ -1,4 +1,2 @@ transformers>=4.51.1 -einops sentencepiece -pillow diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 89832abe65..b63653bb53 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -9,3 +9,5 @@ tyro tokenizers >= 0.15.0 safetensors psutil +einops +pillow diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 0a3976248f..327b0bec23 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -28,7 +28,8 @@ jobs: run: python -m pip install --upgrade pip - name: Install lint utilities run: | - python -m pip install pre-commit + python -m pip install -r requirements.txt -r requirements-dev.txt + python -m pip install --force-reinstall --pre --index-url https://download.pytorch.org/whl/nightly/cu126 torch pre-commit install-hooks - name: Get changed files id: changed-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc996e5046..6f8542fab4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,3 +61,11 @@ repos: types: [text] additional_dependencies: - tomli + +- repo: https://github.com/facebook/pyrefly-pre-commit + rev: 0.45.1 + hooks: + - id: pyrefly-check + name: Pyrefly (type checking) + pass_filenames: false + language: system diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8de2b9df9d..de6373236a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ possible. 
Contributions should follow the [Contributing Guidelines](#contributin ### Setup ``` -pip install -r requirements-dev.txt +pip install -r requirements.txt -r requirements-dev.txt ``` ### Pull Requests diff --git a/pyproject.toml b/pyproject.toml index efe74d3030..7a3687590c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ "tyro", "tensorboard", "psutil", + "einops", + "pillow", ] dynamic = ["version"] @@ -62,3 +64,7 @@ include = ["torchtitan*"] [tool.pytest.ini_options] addopts = ["--showlocals"] # show local variables in tracebacks testpaths = ["tests"] + +[tool.pyrefly] +project-excludes = ["torchtitan/experiments", "**/tests/**"] +ignore-missing-imports = ["torchao.*", "torchft"] # optional dependencies diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index fae7eec17b..77bfeddd59 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -16,16 +16,16 @@ @torch.inference_mode() def convert_from_hf(input_dir, output_dir, model_name, model_flavor): - if model_name == "flux": - import torchtitan.experiments.flux # noqa: F401 # initialize model to allocate memory for state dict train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] with torch.device("cpu"): model = train_spec.model_cls(model_args) + # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) + # pyrefly: ignore [not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, None) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index ad13850b82..e68a6d2acc 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -30,8 +30,10 @@ def convert_to_hf( with torch.device("cpu"): model = train_spec.model_cls(model_args) + # pyrefly: ignore [bad-argument-type] model = ModelWrapper(model) + # pyrefly: ignore [not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/numerical_tests_example.py b/scripts/checkpoint_conversion/numerical_tests_example.py index 66eff8054e..f52851ef9b 100644 --- a/scripts/checkpoint_conversion/numerical_tests_example.py +++ b/scripts/checkpoint_conversion/numerical_tests_example.py @@ -25,7 +25,7 @@ def loss_fn(logits1, logits2): probs2 = F.softmax(logits2, dim=-1) # Calculate KL Divergence - kl_loss = F.kl_div(probs1, probs2, "mean") + kl_loss = F.kl_div(probs1, probs2, reduction="mean") return kl_loss @@ -75,10 +75,13 @@ def forward_tt(config_path, checkpoint_path, test_set): # materalize model device = torch.device(device_type) + # pyrefly: ignore [missing-attribute] model.to_empty(device=device) model.init_weights(buffer_device=device) + # pyrefly: ignore [missing-attribute] model.eval() + # pyrefly: ignore [bad-argument-type] modelWrapper = ModelWrapper(model) state_dict = modelWrapper._get_state_dict() @@ -94,6 +97,7 @@ def forward_tt(config_path, checkpoint_path, test_set): input_ids = input_ids.unsqueeze(0) # obtains the logits of only the last token in the predictions + # pyrefly: ignore [not-callable] predictions = model(input_ids)[:, -1, :].unsqueeze(1) output_list.append(predictions) @@ -120,6 +124,7 @@ def forward_tt(config_path, checkpoint_path, test_set): config_manager = ConfigManager() config = 
config_manager.parse_args([f"--job.config_file={config_path}"]) train_spec = get_train_spec(config.model.name) + # pyrefly: ignore [not-callable] tokenizer = train_spec.build_tokenizer_fn(config) # Build test set of randomly generated token ids @@ -150,10 +155,11 @@ def forward_tt(config_path, checkpoint_path, test_set): avg_losses = {} for test_name, (baseline_outputs, conversion_outputs) in test_configs.items(): - total_loss = 0 + total_loss: int | torch.Tensor = 0 for baseline, outputs in zip(baseline_outputs, conversion_outputs): total_loss += loss_fn(baseline, outputs) avg_loss = total_loss / len(test_set) + # pyrefly: ignore [missing-attribute] avg_losses[test_name] = avg_loss.item() for test_name, avg_loss in avg_losses.items(): diff --git a/scripts/download_hf_assets.py b/scripts/download_hf_assets.py index e1092b2d70..dbe8ba98b6 100644 --- a/scripts/download_hf_assets.py +++ b/scripts/download_hf_assets.py @@ -167,6 +167,7 @@ def should_download(patterns: list[str], filename: str) -> bool: missed_files = [] # Download files with progress bar + # pyrefly: ignore [bad-context-manager] with tqdm(total=len(files_found), desc="Downloading files", unit="file") as pbar: for filename in files_found: try: diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index e0a752d545..bfa9dddfd2 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -98,44 +98,58 @@ def estimate_memory(job_config: JobConfig): # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) + # pyrefly: ignore [bad-argument-type] model_converters.convert(model) # apply PT-D DP/TP parallelisms and activation checkpointing train_spec.parallelize_fn(model, parallel_dims, job_config) + # pyrefly: ignore [missing-attribute] model.to_empty(device="cuda") if not active_fake_mode(): model.init_weights() + # pyrefly: ignore [missing-attribute] model.train() # build optimizer after applying parallelisms to the model + # pyrefly: ignore [bad-argument-type] optimizers = build_optimizers([model], job_config.optimizer, parallel_dims) lr_schedulers = build_lr_schedulers( - optimizers.optimizers, job_config.lr_scheduler, job_config.training.steps + # pyrefly: ignore [bad-argument-type] + optimizers.optimizers, + job_config.lr_scheduler, + job_config.training.steps, ) # Post optimizer step model converters hook. # e.g. 
calculate float8 dynamic amax/scale for all-parameter for FSDP2 # where it issues a single all-reduce for all parameters at once for better performance optimizers.register_step_post_hook( + # pyrefly: ignore [bad-argument-type] lambda *args, **kwargs: model_converters.post_optimizer_hook(model) ) + # pyrefly: ignore [missing-attribute] logger.info(f"Vocab size: {model_args.vocab_size}") # Create a dummy batch instead of loading from a dataset batch = ( torch.randint( 0, + # pyrefly: ignore [missing-attribute] model_args.vocab_size, + # pyrefly: ignore [missing-attribute] (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), torch.randint( 0, + # pyrefly: ignore [missing-attribute] model_args.vocab_size, + # pyrefly: ignore [missing-attribute] (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), ) + # pyrefly: ignore [bad-argument-type] fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0]) fsdp_memtracker.track_inputs(batch) @@ -145,6 +159,7 @@ def estimate_memory(job_config: JobConfig): input_ids, labels = batch # train step with train_context(): + # pyrefly: ignore [not-callable] pred = model(input_ids) loss = loss_fn(pred, labels) del pred @@ -152,7 +167,10 @@ def estimate_memory(job_config: JobConfig): # clip gradients torch.nn.utils.clip_grad_norm_( - model.parameters(), job_config.training.max_norm, foreach=True + # pyrefly: ignore [missing-attribute] + model.parameters(), + job_config.training.max_norm, + foreach=True, ) # optimizer step optimizers.step() diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index b1d19ad17f..bff9c2aa7f 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -36,6 +36,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) +# pyrefly: ignore [missing-import] from generate._generation import generate @@ -49,6 +50,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): }, ) + # pyrefly: ignore [missing-attribute] for _, transformer_block in model.layers.items(): layer_plan = { "attention.wq": ColwiseParallel(), @@ -63,6 +65,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): parallelize_module( module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -95,6 +98,7 @@ def test_generate( world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) device = torch.device(f"{device_type}:{local_rank}") + # pyrefly: ignore [missing-attribute] device_module.set_device(device) device_memory_monitor = build_device_memory_monitor() @@ -103,6 +107,7 @@ def test_generate( logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") # Tokenizer setup + # pyrefly: ignore [not-callable] tokenizer = train_spec.build_tokenizer_fn(config) model_args = train_spec.model_args[config.model.flavor] @@ -131,6 +136,7 @@ def test_generate( # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 + # pyrefly: ignore [bad-argument-type] apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) debug_config = DebugConfig(seed=seed, deterministic=deterministic) @@ -142,11 +148,14 @@ def test_generate( ) # materalize model + # pyrefly: ignore [missing-attribute] model.to_empty(device=device_type) with torch.no_grad(): model.init_weights() + # pyrefly: ignore [missing-attribute] model.eval() + # pyrefly: ignore 
[missing-attribute] state_dict = model.state_dict() # Checkpoint Loading diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py index 3479875036..e9761458a8 100644 --- a/scripts/loss_compare.py +++ b/scripts/loss_compare.py @@ -134,6 +134,7 @@ def run_with_realtime_output(cmd: str, logfile: str, env: dict[str, Any]) -> Non bufsize=1, ) + # pyrefly: ignore [not-iterable] for line in process.stdout: print(line, end="") log_f.write(line) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 79918d0046..7928f514ba 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -29,7 +29,10 @@ set_model_state_dict, StateDictOptions, ) -from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType +from torch.distributed.checkpoint.state_dict_saver import ( + AsyncCheckpointerType, + AsyncSaveResponse, +) from torch.distributed.checkpoint.stateful import Stateful from torchtitan.components.dataloader import BaseDataLoader @@ -174,6 +177,9 @@ class CheckpointManager: """ + mp_queue_send: queue.Queue + purge_thread: threading.Thread | None + def __init__( self, dataloader: BaseDataLoader | None, @@ -208,12 +214,14 @@ def __init__( ) if self.ft_manager and not self.enable_ft_dataloader_checkpoints: + # pyrefly: ignore [deprecated] logger.warn( "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. " "This means replicas can retrain over the same data multiple times, which can result in overfitting." ) if self.ft_manager: + # pyrefly: ignore [missing-attribute] optimizers.init_cache_state_dict() def state_dict(): @@ -233,7 +241,9 @@ def load_state_dict(state_dict): for k, v in state_dict.items(): self.states[k].load_state_dict(v) + # pyrefly: ignore [missing-attribute] self.ft_manager.set_state_dict_fns(load_state_dict, state_dict) + # pyrefly: ignore [missing-attribute] self.ft_replica_id = ft_manager.replica_id async_mode = checkpoint_config.async_mode.lower() @@ -344,7 +354,7 @@ def dcp_save( async_mode: AsyncMode, enable_garbage_collection: bool = False, to_hf: bool = False, - ) -> Future | None: + ) -> Future | AsyncSaveResponse | None: """Save the checkpoint with dcp. Args: state_dict (dict): The state dict to save. @@ -357,7 +367,7 @@ def dcp_save( Future: The future object if the checkpoint is async, otherwise None. 
""" - ret: Future | None = None + ret: Future | AsyncSaveResponse | None = None storage_writer: HuggingFaceStorageWriter | None = None checkpoint_save_id: str | None = None @@ -394,6 +404,7 @@ def dcp_save( state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id, + # pyrefly: ignore [bad-argument-type] process_group=self.pg, ) elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: @@ -401,6 +412,7 @@ def dcp_save( state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id, + # pyrefly: ignore [bad-argument-type] process_group=self.pg, async_checkpointer_type=AsyncCheckpointerType.PROCESS, async_stager=self.stager, @@ -412,10 +424,12 @@ def dcp_save( checkpoint_id=checkpoint_save_id, ) + # pyrefly: ignore [missing-attribute] if to_hf and self.sd_adapter.fqn_to_index_mapping: consolidate_safetensors_files_on_every_rank( input_dir=os.path.join(checkpoint_id, "sharded"), output_dir=checkpoint_id, + # pyrefly: ignore [bad-argument-type] fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping, num_threads=5, ) @@ -489,7 +503,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: begin = time.monotonic() if not self.enable_ft_dataloader_checkpoints or ( - self.ft_manager and self.ft_manager.participating_rank() == 0 + self.ft_manager + # pyrefly: ignore [missing-attribute] + and self.ft_manager.participating_rank() == 0 ): logger.info("Saving the checkpoint (or staging if async is enabled).") checkpoint_id = self._create_checkpoint_id(curr_step) @@ -511,7 +527,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: checkpoint_id=checkpoint_id, async_mode=self.async_mode, ) + # pyrefly: ignore [missing-attribute] self.save_future = result.upload_completion + # pyrefly: ignore [missing-attribute] self.staging_future = result.staging_completion self.staging = True elif self.async_mode == AsyncMode.ASYNC: @@ -537,6 +555,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None: assert self.ft_manager is not None logger.info( "Replica %d doesn't save checkpoint.", + # pyrefly: ignore [missing-attribute] self.ft_manager.participating_rank(), ) @@ -589,6 +608,7 @@ def load(self, step: int = -1) -> bool: f"loading from HF safetensors from --checkpoint.initial_load_path: {self.initial_load_path}" ) elif from_hf: + # pyrefly: ignore [missing-attribute] checkpoint_id = self.sd_adapter.hf_assets_path if not os.path.isdir(checkpoint_id): raise ValueError( @@ -596,6 +616,7 @@ def load(self, step: int = -1) -> bool: Either make sure hf_assets_path is correct or provide a valid checkpoint.initial_load_path" ) logger.info( + # pyrefly: ignore [missing-attribute] f"loading HF safetensors from --model.hf_assets_path: {self.sd_adapter.hf_assets_path}" ) else: @@ -644,6 +665,7 @@ def maybe_wait_for_staging(self) -> None: with ``async_checkpoint_with_pinned_memory``. """ if self.enable_staging and self.staging: + # pyrefly: ignore [missing-attribute] self.staging_future.result() self.staging = False @@ -828,6 +850,7 @@ def _purge_stale_checkpoints(self): and os.path.isdir(self.folder) and ( not self.enable_ft_dataloader_checkpoints + # pyrefly: ignore [missing-attribute] or (self.ft_manager and self.ft_manager.participating_rank() == 0) ) ): diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 071af84d54..7a1c1fcad6 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -41,6 +41,7 @@ def __iter__(self): ... 
+# pyrefly: ignore [inconsistent-inheritance] class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): """Dataloader that is aware of distributed data parallelism. @@ -58,7 +59,7 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): dp_rank: int dp_world_size: int - batch_size: int + batch_size: int | None def __init__( self, diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 5d64d34b09..d95470c47d 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import importlib +import importlib.util from contextlib import nullcontext from datetime import timedelta from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union @@ -165,4 +165,5 @@ def maybe_semi_sync_training( raise ValueError( f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported." ) + # pyrefly: ignore [no-matching-overload] return nullcontext() diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 6384feb641..15a3fc6bd1 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -176,6 +176,8 @@ def linear_warmup_stable_decay( curr_adjustment = 1 - math.sqrt(progress) elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) + else: + raise ValueError(f"Unknown lr_decay_type: {lr_decay_type}") curr_adjustment = min_lr_factor + (1 - min_lr_factor) * curr_adjustment return curr_adjustment diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 6905fb5b53..6f50337473 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -40,15 +40,21 @@ class DeviceMemoryMonitor: def __init__(self, device: str = f"{device_type}:0"): + # pyrefly: ignore [read-only] self.device = torch.device(device) # device object + # pyrefly: ignore [missing-attribute] self.device_name = device_module.get_device_name(self.device) + # pyrefly: ignore [missing-attribute] self.device_index = device_module.current_device() + # pyrefly: ignore [missing-attribute] self.device_capacity = device_module.get_device_properties( self.device ).total_memory self.device_capacity_gib = self._to_gib(self.device_capacity) + # pyrefly: ignore [missing-attribute] device_module.reset_peak_memory_stats() + # pyrefly: ignore [missing-attribute] device_module.empty_cache() def _to_gib(self, memory_in_bytes): @@ -61,6 +67,7 @@ def _to_pct(self, memory): return 100 * memory / self.device_capacity def get_peak_stats(self): + # pyrefly: ignore [missing-attribute] device_info = device_module.memory_stats(self.device) max_active = device_info.get("active_bytes.all.peak", -1) @@ -91,6 +98,7 @@ def get_peak_stats(self): ) def reset_peak_stats(self): + # pyrefly: ignore [missing-attribute] device_module.reset_peak_memory_stats() @@ -341,7 +349,7 @@ class MetricsProcessor: device_memory_monitor: DeviceMemoryMonitor color: utils.NoColor | utils.Color - gpu_peak_flops: int + gpu_peak_flops: float ntokens_since_last_log: int data_loading_times: list[float] time_last_log: float diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 80557366da..2b08142f97 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -8,6 +8,7 @@ from typing import Any, Generic, 
Iterator, TypeVar import torch +import torch.distributed.tensor import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl from torch.distributed.checkpoint.state_dict import ( @@ -88,6 +89,7 @@ def __iter__(self) -> Iterator[T]: def __len__(self) -> int: return len(self.optimizers) + # pyrefly: ignore [bad-override] def step(self, *args, **kwargs) -> None: for optimizer in self.optimizers: optimizer.step(*args, **kwargs) @@ -170,9 +172,11 @@ def optim_hook(param) -> None: ) self._post_init(all_params, optimizer_kwargs) + # pyrefly: ignore [bad-override] def step(self) -> None: pass + # pyrefly: ignore [bad-override] def zero_grad(self) -> None: pass @@ -343,9 +347,12 @@ def build_optimizers_with_moe_load_balancing( def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if transformer_block.moe_enabled: # Assumption: load_balance_coeff is set universally on all moe blocks. + # pyrefly: ignore [missing-attribute] return bool(transformer_block.moe.load_balance_coeff) return False @@ -364,11 +371,15 @@ def _update_expert_bias( # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue + # pyrefly: ignore [missing-attribute] if transformer_block.moe.load_balance_coeff is None: return + # pyrefly: ignore [missing-attribute] tokens_per_expert = transformer_block.moe.tokens_per_expert if _is_recomputation_enabled(transformer_block): # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice. 
@@ -398,9 +409,12 @@ def _update_expert_bias( moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue + # pyrefly: ignore [missing-attribute] moe = transformer_block.moe tokens_per_expert = tokens_per_expert_by_layer[ diff --git a/torchtitan/components/quantization/__init__.py b/torchtitan/components/quantization/__init__.py index de94c37b3e..49faf60733 100644 --- a/torchtitan/components/quantization/__init__.py +++ b/torchtitan/components/quantization/__init__.py @@ -42,7 +42,7 @@ def _validate(job_config: JobConfig): # quantization converter format: # `quantize.[linear | grouped_mm].[float8 | mx]` quantization_type = lambda converter: converter.split(".")[-1] - existing_quantization_converter = None + existing_quantization_converter: str | None = None for converter in job_config.model.converters: if "quantize" in converter: if existing_quantization_converter is None: diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 86932a17bd..9b575876e7 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -6,6 +6,7 @@ from functools import partial import torch +import torch._inductor.config import torch.nn as nn from torchtitan.components.quantization import ( FP8_GROUP_ALIGNMENT_SIZE, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index a474cc3918..f1c0e09574 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -57,14 +57,19 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): MXLinearConfig as TorchAOMXLinearConfig, ) + # pyrefly: ignore [bad-assignment] mx_job_config: TorchAOMXLinearConfig = job_config.quantize.linear.mx + # pyrefly: ignore [missing-attribute] config = TorchAOMXLinearConfig.from_recipe_name(mx_job_config.recipe_name) + # pyrefly: ignore [missing-attribute] config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[ mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() ] + # pyrefly: ignore [missing-attribute] self.filter_fqns = mx_job_config.filter_fqns self.config = config self.enabled = True + # pyrefly: ignore [missing-attribute] logger.info(f"MX training active with recipe {mx_job_config.recipe_name}") def convert(self, model: nn.Module): diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 022fcbc266..aca2300abe 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -56,6 +56,7 @@ def __init__( # Initialize BOS/EOS token attributes (frequently used) self.bos_id = None + # pyrefly: ignore [bad-assignment] self.eos_id = None self.bos_token = None self.eos_token = None @@ -144,10 +145,13 @@ def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer: tokenizer = Tokenizer(bpe_model) # Configure GPT-2 style components for proper space handling + # pyrefly: ignore [read-only] tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel( add_prefix_space=False ) + # pyrefly: ignore [read-only] tokenizer.decoder = decoders.ByteLevel() + # pyrefly: ignore [read-only] tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) return tokenizer diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 93fb68a3cc..4673807347 100644 --- 
a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Generator +from collections.abc import Callable +from contextlib import AbstractContextManager +from typing import TypeAlias import torch import torch.nn as nn @@ -19,6 +21,11 @@ from torchtitan.tools import utils from torchtitan.tools.logging import logger +ValidationContext: TypeAlias = Callable[ + [AbstractContextManager[None] | None], + AbstractContextManager[None], +] + class BaseValidator: def __init__(self, job_config: JobConfig): @@ -52,8 +59,8 @@ def __init__( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -83,6 +90,7 @@ def __init__( ) @torch.no_grad() + # pyrefly: ignore [bad-override] def validate( self, model_parts: list[nn.Module], @@ -98,6 +106,7 @@ def validate( device_type = utils.device_type num_steps = 0 + # pyrefly: ignore [not-iterable] for input_dict, labels in self.validation_dataloader: if ( self.job_config.validation.steps != -1 @@ -186,8 +195,8 @@ def build_validator( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -203,6 +212,7 @@ def build_validator( loss_fn=loss_fn, validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, + # pyrefly: ignore [bad-argument-type] metrics_processor=metrics_processor, pp_schedule=pp_schedule, pp_has_first_stage=pp_has_first_stage, diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 10f4440a4c..79d95c350e 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -16,6 +16,7 @@ try: import tomllib except ModuleNotFoundError: + # pyrefly: ignore [missing-import] import tomli as tomllib from torchtitan.tools.logging import logger @@ -253,7 +254,10 @@ def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): # # ----------------------------------------------------------------------------- + # pyrefly: ignore [missing-import] from rich import print as rprint + + # pyrefly: ignore [missing-import] from rich.pretty import Pretty config_manager = ConfigManager() diff --git a/torchtitan/distributed/__init__.py b/torchtitan/distributed/__init__.py index 63690a660b..f335916595 100644 --- a/torchtitan/distributed/__init__.py +++ b/torchtitan/distributed/__init__.py @@ -65,7 +65,10 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: device_mesh, None, partial( - self._prepare_input_fn, self.input_layout, self.desired_input_layout + # pyrefly: ignore [bad-argument-type] + self._prepare_input_fn, + self.input_layout, + self.desired_input_layout, ), partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) diff --git a/torchtitan/distributed/activation_checkpoint.py 
b/torchtitan/distributed/activation_checkpoint.py index 0eecde9052..c0b550a5c1 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -11,6 +11,7 @@ from collections import defaultdict import torch +import torch._functorch.config import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, @@ -221,6 +222,7 @@ def apply_ac( torch._functorch.config.activation_memory_budget = ac_config.memory_budget logger.info(f"Selected {ac_config.memory_budget} budget option") else: + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block( transformer_block, @@ -229,6 +231,7 @@ def apply_ac( model_compile_enabled=model_compile_enabled, op_sac_save_list=op_sac_save_list, ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py index 5a4a5d9dd0..5def0e40e6 100644 --- a/torchtitan/distributed/dual_pipe_v.py +++ b/torchtitan/distributed/dual_pipe_v.py @@ -91,7 +91,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -145,6 +147,7 @@ def is_coordination_enabled(self): class SyncHook(torch.autograd.Function): @staticmethod + # pyrefly: ignore [bad-override] def forward(ctx, x, hook_name=""): ctx.hook_name = hook_name # handle edge case for transformer level boundary @@ -158,6 +161,7 @@ def forward(ctx, x, hook_name=""): return x @staticmethod + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): hook_name = ctx.hook_name @@ -260,19 +264,24 @@ def overlap_callback(action: _Action, ctx: _PipelineContext): # Shared container for exception from backward thread def run_backward(): + # pyrefly: ignore [missing-attribute] schedule._assert_unsharded(backward_stage) # Set the backward thread to use the same stream as forward + # pyrefly: ignore [missing-attribute] device_module.set_stream(main_stream) with record_function( f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" ): loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) + # pyrefly: ignore [missing-attribute] schedule.backward_counter[backward_stage_index] += 1 last_backward = ( + # pyrefly: ignore [missing-attribute] schedule.backward_counter[backward_stage_index] == schedule._n_microbatches ) backward_stage.backward_one_chunk( + # pyrefly: ignore [bad-argument-type] backward_mb_index, loss=loss, full_backward=True, @@ -282,14 +291,19 @@ def run_backward(): if backward_is_prev_stage_on_this_rank: stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( backward_stage.get_local_bwd_output(backward_mb_index), + # pyrefly: ignore [bad-argument-type] backward_mb_index, ) def run_forward(): + # pyrefly: ignore [missing-attribute] schedule._assert_unsharded(forward_stage) output = forward_stage.forward_one_chunk( + # pyrefly: ignore [bad-argument-type] forward_mb_index, + # pyrefly: ignore [bad-index, unsupported-operation] arg_mbs[forward_mb_index], + # pyrefly: ignore [bad-index, unsupported-operation] kwarg_mbs[forward_mb_index], ) 
schedule._maybe_compute_loss( @@ -297,7 +311,9 @@ def run_forward(): ) if forward_is_next_stage_on_this_rank: stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input( - output, forward_mb_index + output, + # pyrefly: ignore [bad-argument-type] + forward_mb_index, ) # Run forward and backward in parallel diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 932a7e4aa1..60de27b276 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -81,6 +81,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, self._partition_fn, + # pyrefly: ignore [bad-argument-type] self._prepare_input_fn, ) @@ -184,7 +185,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -210,18 +213,21 @@ def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> N # w1 shape = (experts, out_dim, in_dim) mod.register_parameter( "w1", + # pyrefly: ignore [bad-argument-type] nn.Parameter(distribute_tensor(mod.w1, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding # w2 shape = (experts, in_dim, out_dim) mod.register_parameter( "w2", + # pyrefly: ignore [bad-argument-type] nn.Parameter(distribute_tensor(mod.w2, device_mesh, [Shard(0), Shard(2)])), ) # Row-wise sharding # w3 shape = (experts, out_dim, in_dim) mod.register_parameter( "w3", + # pyrefly: ignore [bad-argument-type] nn.Parameter(distribute_tensor(mod.w3, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding @@ -234,7 +240,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -296,6 +304,8 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=None, + # pyrefly: ignore [bad-argument-type] input_fn=self._prepare_inputput_fn, + # pyrefly: ignore [bad-argument-type] output_fn=self._prepare_output_fn, ) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 44822039a6..187a363097 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -26,7 +26,7 @@ class ParallelDims: etp: int world_size: int - _world_mesh: DeviceMesh = None + _world_mesh: DeviceMesh | None = None def __post_init__(self): self._validate() @@ -105,7 +105,7 @@ def _build_mesh_with_ep(self) -> DeviceMesh: names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) # Create all the submesh here to ensure all required process groups are # initialized: @@ -156,7 +156,7 @@ def _build_mesh_without_ep(self) -> DeviceMesh: names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + mesh = init_device_mesh(device_type, tuple(dims), mesh_dim_names=tuple(names)) # Create all the submesh here to ensure all required process groups are # initialized: diff --git 
a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index bafefddbec..bef597be24 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -200,7 +200,9 @@ def build_pipeline_schedule( f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." ) + # pyrefly: ignore [bad-instantiation] schedule = schedule_class( + # pyrefly: ignore [bad-argument-type] stages if looped_schedule else stages[0], n_microbatches=n_microbatches, loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches), @@ -225,6 +227,7 @@ def build_pipeline_schedule( "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), " "and _PipelineScheduleRuntime support csv schedules" ) + # pyrefly: ignore [missing-attribute] schedule._load_csv(pp_schedule_csv) return schedule @@ -445,7 +448,7 @@ def _build_stage_from_modules( "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" ) - def _get_stage_indices() -> tuple[int]: + def _get_stage_indices() -> tuple[int, ...]: """ Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule @@ -464,6 +467,8 @@ def _get_stage_indices() -> tuple[int]: zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) ) return stage_v_pairs[pp_rank] + else: + raise ValueError(f"Unknown style {style}") for stage_idx in _get_stage_indices(): module_names = module_names_per_stage[stage_idx] diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py index 04e4e36c3a..59fffc86a2 100644 --- a/torchtitan/distributed/tensor_parallel.py +++ b/torchtitan/distributed/tensor_parallel.py @@ -6,6 +6,7 @@ import torch +import torch._inductor.config from torch.distributed.device_mesh import DeviceMesh from torchtitan.config import JobConfig diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 6a73ffd083..811e062958 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -7,12 +7,16 @@ import contextlib import math import os -from collections.abc import Generator, Iterable +from abc import abstractmethod +from collections.abc import Iterable from datetime import timedelta +from typing import Protocol import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d +import torch.distributed.tensor._random +import torch.distributed.tensor.parallel from torch import distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor @@ -174,11 +178,15 @@ def set_determinism( # Filter out all distinct dimensions to get duplicate_seed_mesh duplicate_seed_mesh_dims = [ name + # pyrefly: ignore [not-iterable] for name in world_mesh.mesh_dim_names if name not in distinct_dims_in_mesh ] duplicate_seed_mesh = ( - world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None + # pyrefly: ignore [bad-index] + world_mesh[duplicate_seed_mesh_dims] + if duplicate_seed_mesh_dims + else None ) else: duplicate_seed_mesh = world_mesh @@ -192,6 +200,7 @@ def set_determinism( # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. # IF PP is also used, this seed is unique per PP rank. 
if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None: + # pyrefly: ignore [bad-argument-type] torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh) @@ -205,11 +214,11 @@ def create_context_parallel_ctx( try: from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental._attention import set_rotate_method - except ImportError: - print( + except ImportError as e: + raise ValueError( f"PyTorch version {torch.__version__} does not include the experimental " "Context Parallel API. Please update to a newer version." - ) + ) from e set_rotate_method(cp_rotate_method) return context_parallel( @@ -220,9 +229,18 @@ def create_context_parallel_ctx( ) -def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]: +class TrainContext(Protocol): + @abstractmethod + def __call__( + self, + cp_context: contextlib.AbstractContextManager[None] | None = None, + ) -> contextlib.AbstractContextManager[None]: + pass + + +def get_train_context(enable_loss_parallel: bool) -> TrainContext: @contextlib.contextmanager - def context(cp_context: Generator[None, None, None] | None = None): + def context(cp_context: contextlib.AbstractContextManager[None] | None = None): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) @@ -236,8 +254,8 @@ def context(cp_context: Generator[None, None, None] | None = None): def maybe_enable_amp( - parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device -) -> Generator[None, None, None]: + parallel_dims: ParallelDims, mixed_precision_param: str, device_type: str +) -> contextlib.AbstractContextManager[None]: if parallel_dims.fsdp_enabled: # FSDP handles mixed precision internally logger.info("Mixed precision training is handled by fully_shard") @@ -252,6 +270,7 @@ def maybe_enable_amp( else: # the following code will only be executed for DDP or single-device training logger.info("Mixed precision training is handled by AMP") + # pyrefly: ignore [bad-return] return torch.autocast( device_type, dtype=TORCH_DTYPE_MAP[mixed_precision_param], @@ -367,7 +386,9 @@ def set_pg_timeouts(timeout, world_mesh): # otherwise, some ranks may issue collectives with the new/shorter timeout and # those may time out, before other ranks have finished with initialization done # under the old/slow timeout. 
+ # pyrefly: ignore [missing-attribute] torch.distributed.barrier(device_ids=[device_module.current_device()]) + # pyrefly: ignore [missing-attribute] device_module.synchronize() groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] @@ -477,6 +498,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) + # pyrefly: ignore [not-iterable] if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) @@ -491,6 +513,7 @@ def _clip_grad_norm_with_ep( if isinstance(ep_grads_total_norm, DTensor): ep_grads_total_norm = ep_grads_total_norm.full_tensor() + # pyrefly: ignore [missing-attribute] non_ep_grads_total_norm = torch.nn.utils.get_total_norm( non_ep_grads, norm_type, error_if_nonfinite, foreach ).full_tensor() diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 493cd1abb4..63790b8862 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -153,7 +153,7 @@ def load_state_dict(self, state_dict): self._data.load_state_dict(state_dict["data"]) def state_dict(self): - _state_dict = {"token_buffer": self._token_buffer} + _state_dict: dict[str, Any] = {"token_buffer": self._token_buffer} if isinstance(self._data, Dataset): _state_dict["sample_idx"] = self._sample_idx diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 819dbd57bc..b04a6a136e 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -20,6 +20,7 @@ ) from torch.nn.attention.varlen import varlen_attn +from torch.types import Number __all__ = [ @@ -43,8 +44,8 @@ class VarlenMetadata(NamedTuple): cu_seq_q: torch.Tensor cu_seq_k: torch.Tensor - max_q: int - max_k: int + max_q: Number + max_k: Number class VarlenAttentionWrapper(torch.nn.Module): @@ -66,8 +67,11 @@ def forward( max_k = attention_masks.max_k n_local_heads = xq.shape[1] + # pyrefly: ignore [no-matching-overload] xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + # pyrefly: ignore [no-matching-overload] xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + # pyrefly: ignore [no-matching-overload] xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) return VarlenAttentionWrapper._compiled_varlen_attn( @@ -146,7 +150,7 @@ class ScaledDotProductAttentionWrapper(torch.nn.Module): """ # TODO: remove sdpa_backends after PyTorch 2.9 is released. - sdpa_backends: ClassVar[list[SDPBackend]] = [] + sdpa_backends: list[SDPBackend] = [] def __init__(self) -> None: super().__init__() diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index c068e60a30..63fb910376 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -119,6 +119,7 @@ def parallelize_deepseekv3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -225,6 +226,7 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. 
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -246,6 +248,7 @@ def apply_non_moe_tp( "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if transformer_block.attention.q_lora_rank == 0: layer_plan.update( { @@ -263,6 +266,7 @@ def apply_non_moe_tp( } ) + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -277,8 +281,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index e683905878..64a9d2bb81 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -101,16 +101,17 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): raise NotImplementedError("CP support is only supported for SDPA.") self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 5b17ad0acf..26e0cff2f3 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -240,6 +240,7 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): case "flex": self.inner_attention = FlexAttentionWrapper() case _: + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() def forward( @@ -433,6 +434,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index fd4ec30284..7fd6743600 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -106,6 +106,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): if "moe.experts" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] @@ -128,15 +129,19 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: # keep this path for offline conversion split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + value, + # pyrefly: ignore [missing-attribute] + self.model_args.moe_args.num_experts, ) + # pyrefly: ignore [missing-attribute] for expert_num in range(0, 
self.model_args.moe_args.num_experts): new_key = new_abstract_key.format(layer_num, expert_num) hf_state_dict[new_key] = split_values[expert_num].squeeze() elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) @@ -186,6 +191,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer, titan_abstract_key, layer_num, + # pyrefly: ignore [missing-attribute] self.model_args.moe_args.num_experts, ) @@ -194,6 +200,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] new_key = new_key.format(layer_num) diff --git a/torchtitan/models/flux/__init__.py b/torchtitan/models/flux/__init__.py index 0fee76e60d..d5ec94b1d6 100644 --- a/torchtitan/models/flux/__init__.py +++ b/torchtitan/models/flux/__init__.py @@ -20,6 +20,7 @@ __all__ = [ "FluxModelArgs", "FluxModel", + # pyrefly: ignore [missing-module-attribute] "flux_configs", "parallelize_flux", ] diff --git a/torchtitan/models/flux/flux_datasets.py b/torchtitan/models/flux/flux_datasets.py index f3cf283aa6..906b669001 100644 --- a/torchtitan/models/flux/flux_datasets.py +++ b/torchtitan/models/flux/flux_datasets.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional import numpy as np -import PIL +import PIL.Image import torch from datasets import Dataset, load_dataset @@ -271,6 +271,7 @@ def __iter__(self): # skip low quality image or image with color channel = 1 if sample_dict["image"] is None: + # pyrefly: ignore [missing-attribute] sample = sample.get("__key__", "unknown") logger.warning( f"Low quality image {sample} is skipped in Flux Dataloader." @@ -279,6 +280,7 @@ def __iter__(self): # Classifier-free guidance: Replace some of the strings with empty strings. # Distinct random seed is initialized at the beginning of training for each FSDP rank. 
+ # pyrefly: ignore [missing-attribute] dropout_prob = self.job_config.training.classifier_free_guidance_prob if dropout_prob > 0.0: if torch.rand(1).item() < dropout_prob: diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py index b89887ad51..bffdb2a2e7 100644 --- a/torchtitan/models/flux/inference/infer.py +++ b/torchtitan/models/flux/inference/infer.py @@ -25,6 +25,7 @@ def inference(config: JobConfig): # Distributed processing setup: Each GPU/process handles a subset of prompts world_size = int(os.environ["WORLD_SIZE"]) global_rank = int(os.environ["RANK"]) + # pyrefly: ignore [missing-attribute] original_prompts = open(config.inference.prompts_path).readlines() total_prompts = len(original_prompts) @@ -45,10 +46,12 @@ def inference(config: JobConfig): if prompts: # Generate images for this process's assigned prompts + # pyrefly: ignore [missing-attribute] bs = config.inference.local_batch_size output_dir = os.path.join( config.job.dump_folder, + # pyrefly: ignore [missing-attribute] config.inference.save_img_folder, ) # Create mapping from local indices to global prompt indices @@ -59,6 +62,7 @@ def inference(config: JobConfig): device=trainer.device, dtype=trainer._dtype, job_config=trainer.job_config, + # pyrefly: ignore [bad-argument-type] model=trainer.model_parts[0], prompt=prompts[i : i + bs], autoencoder=trainer.autoencoder, diff --git a/torchtitan/models/flux/inference/sampling.py b/torchtitan/models/flux/inference/sampling.py index f43d0fc2c5..5ee48ab60f 100644 --- a/torchtitan/models/flux/inference/sampling.py +++ b/torchtitan/models/flux/inference/sampling.py @@ -93,10 +93,13 @@ def generate_image( prompt = [prompt] # allow for packing and conversion to latent space. Use the same resolution as training time. 
+ # pyrefly: ignore [missing-attribute] img_height = 16 * (job_config.training.img_size // 16) + # pyrefly: ignore [missing-attribute] img_width = 16 * (job_config.training.img_size // 16) enable_classifier_free_guidance = ( + # pyrefly: ignore [missing-attribute] job_config.validation.enable_classifier_free_guidance ) @@ -104,7 +107,9 @@ def generate_image( clip_tokens = clip_tokenizer.encode(prompt) t5_tokens = t5_tokenizer.encode(prompt) if len(prompt) == 1: + # pyrefly: ignore [missing-attribute] clip_tokens = clip_tokens.unsqueeze(0) + # pyrefly: ignore [missing-attribute] t5_tokens = t5_tokens.unsqueeze(0) batch = preprocess_data( @@ -113,6 +118,7 @@ def generate_image( autoencoder=None, clip_encoder=clip_encoder, t5_encoder=t5_encoder, + # pyrefly: ignore [bad-argument-type] batch={ "clip_tokens": clip_tokens, "t5_tokens": t5_tokens, @@ -124,7 +130,9 @@ def generate_image( empty_clip_tokens = clip_tokenizer.encode("") empty_t5_tokens = t5_tokenizer.encode("") + # pyrefly: ignore [missing-attribute] empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1) + # pyrefly: ignore [missing-attribute] empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1) empty_batch = preprocess_data( @@ -145,16 +153,24 @@ def generate_image( model=model, img_width=img_width, img_height=img_height, + # pyrefly: ignore [missing-attribute] denoising_steps=job_config.validation.denoising_steps, clip_encodings=batch["clip_encodings"], t5_encodings=batch["t5_encodings"], enable_classifier_free_guidance=enable_classifier_free_guidance, empty_t5_encodings=( - empty_batch["t5_encodings"] if enable_classifier_free_guidance else None + # pyrefly: ignore [unbound-name] + empty_batch["t5_encodings"] + if enable_classifier_free_guidance + else None ), empty_clip_encodings=( - empty_batch["clip_encodings"] if enable_classifier_free_guidance else None + # pyrefly: ignore [unbound-name] + empty_batch["clip_encodings"] + if enable_classifier_free_guidance + else None ), + # pyrefly: ignore [missing-attribute] classifier_free_guidance_scale=job_config.validation.classifier_free_guidance_scale, ) @@ -190,7 +206,9 @@ def denoise( if enable_classifier_free_guidance: # Double batch size for CFG: [unconditional, conditional] latents = torch.cat([latents, latents], dim=0) + # pyrefly: ignore [no-matching-overload] t5_encodings = torch.cat([empty_t5_encodings, t5_encodings], dim=0) + # pyrefly: ignore [no-matching-overload] clip_encodings = torch.cat([empty_clip_encodings, clip_encodings], dim=0) bsz *= 2 diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index fc9c926af0..b27fa93a31 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Any import torch import torch.nn as nn @@ -77,7 +78,7 @@ def apply_fsdp( cpu_offload (bool): Whether to offload model parameters to CPU. Defaults to False. 
""" mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + fsdp_config: dict[str, Any] = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() @@ -88,21 +89,27 @@ def apply_fsdp( model.txt_in, ] for layer in linear_layers: + # pyrefly: ignore [no-matching-overload] fully_shard(layer, **fsdp_config) + # pyrefly: ignore [not-iterable] for block in model.double_blocks: + # pyrefly: ignore [no-matching-overload] fully_shard( block, **fsdp_config, ) + # pyrefly: ignore [not-iterable] for block in model.single_blocks: + # pyrefly: ignore [no-matching-overload] fully_shard( block, **fsdp_config, ) # apply FSDP to last layer. Set reshard_after_forward=False for last layer to avoid gather right after reshard + # pyrefly: ignore [no-matching-overload] fully_shard(model.final_layer, **fsdp_config, reshard_after_forward=False) # Wrap all the rest of model @@ -112,12 +119,16 @@ def apply_fsdp( def apply_ac(model: nn.Module, ac_config): """Apply activation checkpointing to the model.""" + # pyrefly: ignore [missing-attribute] for layer_id, block in model.double_blocks.named_children(): block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + # pyrefly: ignore [missing-attribute] model.double_blocks.register_module(layer_id, block) + # pyrefly: ignore [missing-attribute] for layer_id, block in model.single_blocks.named_children(): block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + # pyrefly: ignore [missing-attribute] model.single_blocks.register_module(layer_id, block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") @@ -139,7 +150,7 @@ def parallelize_encoders( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - fsdp_config = { + fsdp_config: dict[str, Any] = { "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], "mp_policy": mp_policy, } @@ -148,8 +159,10 @@ def parallelize_encoders( # NOTE: only apply FSDP to the T5 encoder, not the CLIP text encoder. # CLIP Text encoder has low computation / communication ratio, so it's not necessary to apply FSDP to it. + # pyrefly: ignore [missing-attribute] for block in t5_model.hf_module.encoder.block: fully_shard(block, **fsdp_config) + # pyrefly: ignore [no-matching-overload] fully_shard(t5_model.hf_module, **fsdp_config) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/flux/model/autoencoder.py b/torchtitan/models/flux/model/autoencoder.py index dc6fb1d061..9ca46dff96 100644 --- a/torchtitan/models/flux/model/autoencoder.py +++ b/torchtitan/models/flux/model/autoencoder.py @@ -19,7 +19,7 @@ class AutoEncoderParams: in_channels: int = 3 ch: int = 128 out_ch: int = 3 - ch_mult: tuple[int] = (1, 2, 4, 4) + ch_mult: tuple[int, ...] 
= (1, 2, 4, 4) num_res_blocks: int = 2 z_channels: int = 16 scale_factor: float = 0.3611 @@ -191,17 +191,24 @@ def forward(self, x: Tensor) -> Tensor: hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): + # pyrefly: ignore [bad-index, not-callable] h = self.down[i_level].block[i_block](hs[-1]) + # pyrefly: ignore [bad-argument-type] if len(self.down[i_level].attn) > 0: + # pyrefly: ignore [bad-index, not-callable] h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: + # pyrefly: ignore [not-callable] hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] + # pyrefly: ignore [not-callable] h = self.mid.block_1(h) + # pyrefly: ignore [not-callable] h = self.mid.attn_1(h) + # pyrefly: ignore [not-callable] h = self.mid.block_2(h) # end h = self.norm_out(h) @@ -276,8 +283,11 @@ def forward(self, z: Tensor) -> Tensor: h = self.conv_in(z) # middle + # pyrefly: ignore [not-callable] h = self.mid.block_1(h) + # pyrefly: ignore [not-callable] h = self.mid.attn_1(h) + # pyrefly: ignore [not-callable] h = self.mid.block_2(h) # cast to proper dtype @@ -285,10 +295,14 @@ def forward(self, z: Tensor) -> Tensor: # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): + # pyrefly: ignore [bad-index, not-callable] h = self.up[i_level].block[i_block](h) + # pyrefly: ignore [bad-argument-type] if len(self.up[i_level].attn) > 0: + # pyrefly: ignore [bad-index, not-callable] h = self.up[i_level].attn[i_block](h) if i_level != 0: + # pyrefly: ignore [not-callable] h = self.up[i_level].upsample(h) # end @@ -321,6 +335,7 @@ def __init__(self, params: AutoEncoderParams): resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, + # pyrefly: ignore [bad-argument-type] ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, @@ -330,6 +345,7 @@ def __init__(self, params: AutoEncoderParams): in_channels=params.in_channels, ch=params.ch, out_ch=params.out_ch, + # pyrefly: ignore [bad-argument-type] ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, diff --git a/torchtitan/models/flux/model/hf_embedder.py b/torchtitan/models/flux/model/hf_embedder.py index 90be8767a9..89bed4d248 100644 --- a/torchtitan/models/flux/model/hf_embedder.py +++ b/torchtitan/models/flux/model/hf_embedder.py @@ -19,6 +19,7 @@ def __init__(self, version: str, random_init=False, **hf_kwargs): if random_init: # Initialize CLIP model with random weights for test purpose only self.hf_module = CLIPTextModel._from_config( + # pyrefly: ignore [missing-attribute] CLIPTextModel.config_class.from_pretrained( os.path.join(version, "config.json"), **hf_kwargs ) @@ -31,6 +32,7 @@ def __init__(self, version: str, random_init=False, **hf_kwargs): if random_init: # Initialize T5 model with random weights for test purpose only self.hf_module = T5EncoderModel._from_config( + # pyrefly: ignore [missing-attribute] T5EncoderModel.config_class.from_pretrained( os.path.join(version, "config.json"), **hf_kwargs ) diff --git a/torchtitan/models/flux/model/layers.py b/torchtitan/models/flux/model/layers.py index 923c5a422c..30ba52d3a3 100644 --- a/torchtitan/models/flux/model/layers.py +++ b/torchtitan/models/flux/model/layers.py @@ -6,6 +6,7 @@ # imported from black-forest-labs/FLUX import math +from collections.abc import Sequence from dataclasses import dataclass import torch @@ -34,7 +35,7 @@ def 
apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: Sequence[int]): super().__init__() self.dim = dim self.theta = theta @@ -213,7 +214,9 @@ def init_weights(self): self.txt_mlp[0], self.txt_mlp[2], ): + # pyrefly: ignore [bad-argument-type] nn.init.xavier_uniform_(layer.weight) + # pyrefly: ignore [bad-argument-type] nn.init.constant_(layer.bias, 0) # initialize Modulation layers, SelfAttention layers @@ -346,7 +349,9 @@ def __init__(self, hidden_size: int, patch_size: int, out_channels: int): ) def init_weights(self): + # pyrefly: ignore [bad-argument-type] nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + # pyrefly: ignore [bad-argument-type] nn.init.constant_(self.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.linear.weight, 0) nn.init.constant_(self.linear.bias, 0) diff --git a/torchtitan/models/flux/model/model.py b/torchtitan/models/flux/model/model.py index 6cfb02c9c0..d0f5592871 100644 --- a/torchtitan/models/flux/model/model.py +++ b/torchtitan/models/flux/model/model.py @@ -51,7 +51,9 @@ def __init__(self, model_args: FluxModelArgs): self.hidden_size = model_args.hidden_size self.num_heads = model_args.num_heads self.pe_embedder = EmbedND( - dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim + dim=pe_dim, + theta=model_args.theta, + axes_dim=model_args.axes_dim, ) self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) @@ -95,8 +97,10 @@ def init_weights(self, buffer_device=None): # Initialize transformer blocks: for block in self.single_blocks: + # pyrefly: ignore [not-callable] block.init_weights() for block in self.double_blocks: + # pyrefly: ignore [not-callable] block.init_weights() # Zero-out output layers: diff --git a/torchtitan/models/flux/model/state_dict_adapter.py b/torchtitan/models/flux/model/state_dict_adapter.py index c976df6919..2526bcd521 100644 --- a/torchtitan/models/flux/model/state_dict_adapter.py +++ b/torchtitan/models/flux/model/state_dict_adapter.py @@ -58,6 +58,7 @@ def __init__(self, model_args: FluxModelArgs, hf_assets_path: str | None): if hf_safetensors_indx: self.fqn_to_index_mapping = {} for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + # pyrefly: ignore [missing-attribute] indx = re.search(r"\d+", raw_indx).group(0) self.fqn_to_index_mapping[hf_key] = indx else: @@ -173,6 +174,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): # Extract layer_num and abstract key if necessary if "blocks" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -242,6 +244,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): # extract layer_num and abstract key if necessary if "blocks" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -273,6 +276,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: # combine collected values for tt_fqn, hf_fqn_map in to_combine.items(): + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", tt_fqn).group(0) tt_abstract_key = re.sub(r"(\d+)", "{}", tt_fqn, count=1) combine_values = [] diff --git 
a/torchtitan/models/flux/tokenizer.py b/torchtitan/models/flux/tokenizer.py index b5cca546b9..06fbde2bbb 100644 --- a/torchtitan/models/flux/tokenizer.py +++ b/torchtitan/models/flux/tokenizer.py @@ -46,6 +46,7 @@ def _pad_and_chunk_tokens( def get_vocab_size(self) -> int: return self.tiktokenizer.vocab_size + # pyrefly: ignore [bad-override] def encode(self, text: str | list[str]) -> torch.Tensor: """ Use TikTokenizer to encode the text into tokens, and then pad and chunk the tokens to max_length. @@ -72,6 +73,7 @@ def encode(self, text: str | list[str]) -> torch.Tensor: tokens = self._pad_and_chunk_tokens(tokens, self._max_length, self.pad_id) return torch.tensor(tokens) + # pyrefly: ignore [bad-override] def decode(self, t: List[int]) -> str: """ Decode function. This function will not be called. @@ -96,10 +98,12 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar self.is_clip = "clip" in model_path.lower() if self.is_clip: + # pyrefly: ignore [bad-assignment] self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( model_path, max_length=max_length, **hf_kwargs ) else: + # pyrefly: ignore [bad-assignment] self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( model_path, max_length=max_length, **hf_kwargs ) @@ -107,6 +111,7 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar def get_vocab_size(self) -> int: return self._tokenizer.vocab_size + # pyrefly: ignore [bad-override] def encode( self, s: str | list[str], @@ -125,6 +130,7 @@ def encode( )["input_ids"] return tokens + # pyrefly: ignore [bad-override] def decode(self, t: List[int]) -> str: """ Decode function. This function will not be called. @@ -136,11 +142,15 @@ def build_flux_tokenizer(job_config: JobConfig) -> tuple[BaseTokenizer, BaseToke """ Build the tokenizer for Flux. 
""" + # pyrefly: ignore [missing-attribute] t5_tokenizer_path = job_config.encoder.t5_encoder + # pyrefly: ignore [missing-attribute] clip_tokenzier_path = job_config.encoder.clip_encoder + # pyrefly: ignore [missing-attribute] max_t5_encoding_len = job_config.encoder.max_t5_encoding_len # NOTE: This tokenizer is used for offline CI and testing only, borrowed from llama3 tokenizer + # pyrefly: ignore [missing-attribute] if job_config.training.test_mode: tokenizer_class = FluxTestTokenizer t5_tokenizer_path = clip_tokenzier_path = job_config.model.hf_assets_path diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 5af9959050..3e008fba59 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -48,23 +48,31 @@ def __init__(self, job_config: JobConfig): model_args = self.train_spec.model_args[job_config.model.flavor] self.autoencoder = load_ae( + # pyrefly: ignore [missing-attribute] job_config.encoder.autoencoder_path, + # pyrefly: ignore [missing-attribute] model_args.autoencoder_params, device=self.device, dtype=self._dtype, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ) self.clip_encoder = FluxEmbedder( + # pyrefly: ignore [missing-attribute] version=job_config.encoder.clip_encoder, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ).to(device=self.device, dtype=self._dtype) self.t5_encoder = FluxEmbedder( + # pyrefly: ignore [missing-attribute] version=job_config.encoder.t5_encoder, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ).to(device=self.device, dtype=self._dtype) # Apply FSDP to the T5 model / CLIP model + # pyrefly: ignore [bad-assignment] self.t5_encoder, self.clip_encoder = parallelize_encoders( t5_model=self.t5_encoder, clip_model=self.clip_encoder, @@ -73,6 +81,7 @@ def __init__(self, job_config: JobConfig): ) if job_config.validation.enable: + # pyrefly: ignore [missing-attribute] self.validator.flux_init( device=self.device, _dtype=self._dtype, @@ -164,6 +173,7 @@ def forward_backward_step( loss = self.loss_fn(latent_noise_pred, target) # latent_noise_pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory + # pyrefly: ignore [unsupported-delete] del (latent_noise_pred, noise, target) loss.backward() diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 189385e0f2..32fa7b9f55 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import os -from typing import Generator +from contextlib import AbstractContextManager import torch import torch.nn as nn @@ -15,7 +15,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.components.validate import Validator +from torchtitan.components.validate import ValidationContext, Validator from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.models.flux.flux_datasets import build_flux_validation_dataloader @@ -53,8 +53,8 @@ def __init__( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -63,6 +63,7 @@ def __init__( self.job_config = job_config self.parallel_dims = parallel_dims self.loss_fn = loss_fn + # pyrefly: ignore [missing-attribute] self.all_timesteps = self.job_config.validation.all_timesteps self.validation_dataloader = build_flux_validation_dataloader( job_config=job_config, @@ -74,6 +75,7 @@ def __init__( ) self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp + # pyrefly: ignore [bad-assignment] self.metrics_processor = metrics_processor self.t5_tokenizer, self.clip_tokenizer = build_flux_tokenizer(self.job_config) @@ -91,6 +93,7 @@ def flux_init( t5_encoder: FluxEmbedder, clip_encoder: FluxEmbedder, ): + # pyrefly: ignore [read-only] self.device = device self._dtype = _dtype self.autoencoder = autoencoder @@ -109,9 +112,12 @@ def validate( model.eval() # Disable cfg dropout during validation + # pyrefly: ignore [missing-attribute] training_cfg_prob = self.job_config.training.classifier_free_guidance_prob + # pyrefly: ignore [missing-attribute] self.job_config.training.classifier_free_guidance_prob = 0.0 + # pyrefly: ignore [missing-attribute] save_img_count = self.job_config.validation.save_img_count parallel_dims = self.parallel_dims @@ -120,6 +126,7 @@ def validate( device_type = dist_utils.device_type num_steps = 0 + # pyrefly: ignore [not-iterable] for input_dict, labels in self.validation_dataloader: if ( self.job_config.validation.steps != -1 @@ -137,6 +144,7 @@ def validate( device=self.device, dtype=self._dtype, job_config=self.job_config, + # pyrefly: ignore [bad-argument-type] model=model, prompt=p, autoencoder=self.autoencoder, @@ -150,6 +158,7 @@ def validate( name=f"image_rank{str(torch.distributed.get_rank())}_{step}.png", output_dir=os.path.join( self.job_config.job.dump_folder, + # pyrefly: ignore [missing-attribute] self.job_config.validation.save_img_folder, ), x=image, @@ -270,6 +279,7 @@ def validate( model.train() # re-enable cfg dropout for training + # pyrefly: ignore [missing-attribute] self.job_config.training.classifier_free_guidance_prob = training_cfg_prob @@ -280,8 +290,8 @@ def build_flux_validator( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None 
= None, pp_has_first_stage: bool | None = None, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 13a968be96..63bbc19ff6 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -101,6 +101,7 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -202,6 +203,7 @@ def apply_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -226,8 +228,10 @@ def apply_tp( } parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -242,10 +246,12 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig): Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): transformer_block = torch.compile( transformer_block, backend=compile_config.backend, fullgraph=True ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) logger.info("Compiling each TransformerBlock with torch.compile") @@ -280,6 +286,7 @@ def apply_fsdp( mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: + # pyrefly: ignore [bad-typed-dict-key] fsdp_config["offload_policy"] = CPUOffloadPolicy() match reshard_after_forward_policy: @@ -297,11 +304,13 @@ def apply_fsdp( ) if model.tok_embeddings is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( model.tok_embeddings, **fsdp_config, reshard_after_forward=reshard_after_forward, ) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): fully_shard( transformer_block, @@ -311,6 +320,7 @@ def apply_fsdp( # As an optimization, do not reshard_after_forward the last layers by default # since FSDP would prefetch them immediately after the forward pass if model.norm is not None and model.output is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( [model.norm, model.output], **fsdp_config, @@ -327,6 +337,7 @@ def apply_ddp( if enable_compile: torch._dynamo.config.optimize_ddp = "ddp_optimizer" + # pyrefly: ignore [invalid-param-spec] replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) logger.info("Applied DDP to the model") diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index 81680074eb..79e97dab4c 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -62,9 +62,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for FlexAttention is still in progress." 
) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_dense_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 8982fcca9f..cafd58a52e 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -223,8 +223,10 @@ def __init__(self, model_args: TransformerModelArgs): case "flex": self.inner_attention = FlexAttentionWrapper() case "varlen": + # pyrefly: ignore [bad-assignment] self.inner_attention = VarlenAttentionWrapper() case _: + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): @@ -474,6 +476,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights() if self.norm is not None: self.norm.reset_parameters() @@ -569,12 +572,15 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore [not-callable] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer( h, self.freqs_cis, attention_masks=attention_masks, positions=positions ) + # pyrefly: ignore [not-callable] h = self.norm(h) if self.norm else h + # pyrefly: ignore [not-callable] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 2c386ece0d..f951edd75a 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -81,6 +81,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): if "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] # We need to permute the weights in wq and wk layer in order to account for the difference between @@ -115,6 +116,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] @@ -132,5 +134,6 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: else: new_key = self.from_hf_map[key] + # pyrefly: ignore [unsupported-operation] state_dict[new_key] = value return state_dict diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 7440b3c3f5..8c7601bca4 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Any + import torch import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -130,6 +132,7 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -245,6 +248,7 @@ def apply_non_moe_tp( ) # Apply tensor + sequence parallelism to every transformer block + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -260,6 +264,7 @@ def apply_non_moe_tp( "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -274,8 +279,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -315,7 +322,7 @@ def apply_fsdp( """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + fsdp_config: dict[str, Any] = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() @@ -334,12 +341,14 @@ def apply_fsdp( ) if model.tok_embeddings is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( model.tok_embeddings, **fsdp_config, reshard_after_forward=reshard_after_forward, ) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping # - the router and the shared experts are sharded together with the TransformerBlock @@ -386,6 +395,7 @@ def apply_fsdp( # As an optimization, do not reshard_after_forward the last layers by default # since FSDP would prefetch them immediately after the forward pass if model.norm is not None and model.output is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( [model.norm, model.output], **fsdp_config, @@ -400,49 +410,65 @@ def apply_fsdp( return # forward + # pyrefly: ignore [not-callable] transformer_blocks = list(model.layers.values()) next_transformer_blocks = transformer_blocks[1:] + [None] + # pyrefly: ignore [bad-argument-type] if model.tok_embeddings is not None and len(model.layers) > 0: + # pyrefly: ignore [missing-attribute] model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) for transformer_block, next_transformer_block in zip( transformer_blocks, next_transformer_blocks ): if next_transformer_block is not None: + # pyrefly: ignore [missing-attribute] if next_transformer_block.moe_enabled: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( + # pyrefly: ignore [missing-attribute] [next_transformer_block, next_transformer_block.moe.experts] ) else: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( [next_transformer_block] ) elif model.norm is not None and model.output is not None: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( [model.norm, model.output] ) # backward + # pyrefly: ignore [not-callable] reversed_transformer_blocks = list(reversed(model.layers.values())) prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + # pyrefly: 
ignore [bad-argument-type] if model.norm is not None and model.output is not None and len(model.layers) > 0: + # pyrefly: ignore [missing-attribute] model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) for transformer_block, prev_transformer_block in zip( reversed_transformer_blocks, prev_transformer_blocks ): if prev_transformer_block is not None: + # pyrefly: ignore [missing-attribute] if prev_transformer_block.moe_enabled: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch( + # pyrefly: ignore [missing-attribute] [prev_transformer_block, prev_transformer_block.moe.experts] ) else: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch( [prev_transformer_block] ) elif model.tok_embeddings is not None: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) @@ -456,7 +482,9 @@ def apply_moe_ep_tp( ): assert ep_mesh is not None or tp_mesh is not None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue @@ -478,9 +506,12 @@ def apply_moe_ep_tp( # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + # pyrefly: ignore [missing-attribute] if transformer_block.moe.shared_experts is not None: # input Replicate, output Partial + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update( { "moe.shared_experts.w1": ColwiseParallel(), @@ -491,8 +522,10 @@ def apply_moe_ep_tp( } ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=moe_layer_plan, ) @@ -513,6 +546,7 @@ def apply_moe_ep_tp( experts_plan = DualPipeExpertParallel(experts_plan) parallelize_module( + # pyrefly: ignore [missing-attribute] module=transformer_block.moe.experts, device_mesh=experts_mesh, parallelize_plan=experts_plan, @@ -528,7 +562,9 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b # but it is experimental. 
torch._dynamo.config.capture_scalar_outputs = True # Workaround for https://github.com/pytorch/pytorch/issues/166926 + # pyrefly: ignore [missing-attribute] torch._C._dynamo.eval_frame._set_lru_cache(False) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): if transformer_block.moe_enabled: # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break @@ -582,6 +618,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b fullgraph=True, ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) moe_module._run_experts_grouped_mm = torch.compile( diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index a277ca382e..3520e7e519 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -86,9 +86,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.debug.moe_force_load_balance ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 7c4f073e19..e08f733f28 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -231,6 +231,7 @@ def __init__( case "flex": self.inner_attention = FlexAttentionWrapper() case _: + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self, init_std: float): @@ -513,6 +514,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() @@ -590,11 +592,14 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore [not-callable] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer(h, self.freqs_cis, attention_masks, positions) + # pyrefly: ignore [not-callable] h = self.norm(h) if self.norm else h + # pyrefly: ignore [not-callable] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/model/state_dict_adapter.py b/torchtitan/models/llama4/model/state_dict_adapter.py index 182981c665..c272b2ac10 100644 --- a/torchtitan/models/llama4/model/state_dict_adapter.py +++ b/torchtitan/models/llama4/model/state_dict_adapter.py @@ -52,6 +52,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: to_combine = defaultdict(dict) for key, value in state_dict.items(): if "layers" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -77,6 +78,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key = ( "language_model.model.layers.{}.feed_forward.experts.gate_up_proj" ) + # pyrefly: ignore [unnecessary-comparison] if hf_abstract_key is None: continue to_combine[hf_abstract_key.format(layer_num)][ @@ -85,6 +87,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: # combine collected values for hf_fqn, tt_fqn_map in to_combine.items(): + # pyrefly: ignore 
[missing-attribute] layer_num = re.search(r"\d+", hf_fqn).group(0) combine_values = [] # put into correct order to combine @@ -106,6 +109,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: diff --git a/torchtitan/models/moe/kernels.py b/torchtitan/models/moe/kernels.py index 7aac7b3ac4..a1b1d17771 100644 --- a/torchtitan/models/moe/kernels.py +++ b/torchtitan/models/moe/kernels.py @@ -92,8 +92,11 @@ def fill_indices_wrapper( start_index_values, write_offsets, permuted_indices, + # pyrefly: ignore [bad-argument-type] experts_per_rank, + # pyrefly: ignore [bad-argument-type] num_ranks, + # pyrefly: ignore [bad-argument-type] BLOCK_SIZE=block_size, ) return permuted_indices diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 741c908eab..da58c68b03 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -77,20 +77,20 @@ def _run_experts_for_loop( num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() + num_tokens_per_expert_list = num_tokens_per_expert.tolist() # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) + num_padding = x.shape[0] - sum(num_tokens_per_expert_list) # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, + x_splits = torch.split( + x[: sum(num_tokens_per_expert_list)], + split_size_or_sections=num_tokens_per_expert_list, dim=0, ) out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): + for expert_idx, x_expert in enumerate(x_splits): h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) @@ -148,7 +148,9 @@ def forward( # Convert parameters from DTensors to plain Tensors, to work with # dynamic-shape inputs in EP which cannot be easily expressed as DTensors. w1 = self.w1.to_local() + # pyrefly: ignore [missing-attribute] w2 = self.w2.to_local() + # pyrefly: ignore [missing-attribute] w3 = self.w3.to_local() else: w1 = self.w1 @@ -161,6 +163,7 @@ def forward( # otherwise, EP will handle the padding. if ( not isinstance(self.w1, DTensor) + # pyrefly: ignore [not-iterable] or "ep" not in self.w1.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 5f9f0a73be..c2eaed8de6 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -8,6 +8,7 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. 
import torch +import torch._inductor.config import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh @@ -121,6 +122,7 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -181,7 +183,9 @@ def parallelize_qwen3( ) # Enable weight tying after applying parallelisms + # pyrefly: ignore [missing-attribute] if model.model_args.enable_weight_tying: + # pyrefly: ignore [missing-attribute] model.output.weight = model.tok_embeddings.weight return model @@ -242,6 +246,7 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -260,6 +265,7 @@ def apply_non_moe_tp( "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -274,8 +280,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index 2def3a949a..d0a0556bf1 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -59,7 +59,5 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.debug.moe_force_load_balance ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops(self, model, 2 * self.head_dim, seq_len) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 62b5d0c381..0683b4c42d 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -160,6 +160,9 @@ class Attention(nn.Module): """ + q_norm: nn.RMSNorm | None + k_norm: nn.RMSNorm | None + def __init__(self, model_args: Qwen3ModelArgs): super().__init__() self.n_heads = model_args.n_heads @@ -199,8 +202,10 @@ def __init__(self, model_args: Qwen3ModelArgs): case "flex": self.inner_attention = FlexAttentionWrapper() case "varlen": + # pyrefly: ignore [bad-assignment] self.inner_attention = VarlenAttentionWrapper() case "sdpa": + # pyrefly: ignore [bad-assignment] self.inner_attention = ScaledDotProductAttentionWrapper() case _: raise ValueError(f"Unknown attention type: {self.attn_type}") @@ -476,6 +481,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device) if self.norm is not None: self.norm.reset_parameters() @@ -567,11 +573,14 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore [not-callable] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer(h, self.rope_cache, attention_masks, positions) + # pyrefly: ignore [not-callable] h = self.norm(h) if self.norm else h + 
# pyrefly: ignore [not-callable] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/qwen3/model/state_dict_adapter.py b/torchtitan/models/qwen3/model/state_dict_adapter.py index 11bb8058c0..8dfe4d5aa7 100644 --- a/torchtitan/models/qwen3/model/state_dict_adapter.py +++ b/torchtitan/models/qwen3/model/state_dict_adapter.py @@ -63,6 +63,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] @@ -85,9 +86,12 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: # keep this path for offline conversion split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + value, + # pyrefly: ignore [missing-attribute] + self.model_args.moe_args.num_experts, ) + # pyrefly: ignore [missing-attribute] for expert_num in range(self.model_args.moe_args.num_experts): new_key = new_abstract_key.format(layer_num, expert_num) hf_state_dict[new_key] = split_values[expert_num].squeeze() @@ -96,6 +100,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) @@ -104,6 +109,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: if key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] if self.model_args.enable_weight_tying and key == "output.weight": continue new_key = to_hf_map[key] @@ -121,6 +127,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} if ( + # pyrefly: ignore [missing-attribute] self.model_args.enable_weight_tying and "lm_head.weight" not in hf_state_dict ): @@ -132,6 +139,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=2) layer_num, expert_num = re.findall(r"\d+", key) titan_abstract_key = self.from_hf_map[abstract_key] + assert titan_abstract_key is not None new_key = titan_abstract_key.format(layer_num) # Store the expert's weight in expert_weights_by_layer for concatenating later. 
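Note on the `missing-attribute` suppressions that recur in the state-dict adapter hunks: `re.search()` returns `re.Match | None`, so chaining `.group(0)` onto it is exactly what pyrefly flags. The sketch below (using a made-up checkpoint key, not torchtitan's real FQNs) shows both resolutions that appear in this patch: the inline `# pyrefly: ignore [...]` comment placed on the line above the flagged expression, and explicit narrowing with an assert, as done for `titan_abstract_key` above.

```python
import re


def layer_index_suppressed(key: str) -> str:
    # The pattern used throughout these hunks: suppress the single diagnostic
    # on the next line instead of restructuring working code.
    # pyrefly: ignore [missing-attribute]
    return re.search(r"\d+", key).group(0)


def layer_index_narrowed(key: str) -> str:
    # The alternative taken for `titan_abstract_key`: bind the optional value
    # and narrow it with an assert so no suppression is needed.
    match = re.search(r"\d+", key)
    assert match is not None, f"no layer number in key: {key}"
    return match.group(0)


print(layer_index_suppressed("layers.7.ffn_norm.weight"))  # prints: 7
print(layer_index_narrowed("layers.7.ffn_norm.weight"))    # prints: 7
```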
@@ -155,6 +163,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer, titan_abstract_key, layer_num, + # pyrefly: ignore [missing-attribute] self.model_args.moe_args.num_experts, ) @@ -163,13 +172,16 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] + # pyrefly: ignore [missing-attribute] new_key = new_key.format(layer_num) state_dict[new_key] = value else: new_key = self.from_hf_map[key] + # pyrefly: ignore [unsupported-operation] state_dict[new_key] = value return state_dict diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py index addfa17421..5bf73fbb7e 100644 --- a/torchtitan/models/utils.py +++ b/torchtitan/models/utils.py @@ -96,12 +96,13 @@ def _caculate_indices_from_placements( dim_size: int, dtensor_placements: tuple, device_mesh: DeviceMesh, - ) -> tuple[int, int]: + ) -> tuple[int | None, int | None]: mesh_names = [] dim_i_placements = [] # Find all the device mesh dimensios that shard on dim-i + # pyrefly: ignore [bad-argument-type] for i, name in enumerate(device_mesh.mesh_dim_names): placement = dtensor_placements[i] if placement.dim == dim: @@ -181,7 +182,9 @@ def _get_local_experts_weights( Returns: Dictionary mapping individual expert keys to their DTensor weights """ + # pyrefly: ignore [missing-attribute] device_mesh = grouped_expert_weight.device_mesh + # pyrefly: ignore [missing-attribute] dtensor_placements = grouped_expert_weight.placements # Step 1: Extract dimension-0 placement information @@ -212,6 +215,7 @@ def _get_local_experts_weights( elif isinstance(placement, _StridedShard): # Keep strided shard with same parameters new_placements.append( + # pyrefly: ignore [unexpected-positional-argument] _StridedShard(placement.dim, placement.split_factor) ) else: @@ -284,6 +288,7 @@ def _concatenate_expert_weights_dtensor( sorted_expert_ids = sorted(experts.keys()) sorted_experts = [experts[i] for i in sorted_expert_ids] + # pyrefly: ignore [missing-attribute] local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor assert ( @@ -306,7 +311,7 @@ def _concatenate_expert_weights_dtensor( def _split_experts_weights( self, weight: torch.Tensor, n_experts: int - ) -> list[torch.Tensor]: + ) -> tuple[torch.Tensor, ...]: """ Split the weights of the experts into a list of tensors. Used for offline conversion. @@ -365,7 +370,7 @@ def get_dense_model_nparams_and_flops( model: nn.Module, head_dims: int, seq_len: int, -) -> tuple[int, float]: +) -> tuple[int, int]: """ Args: model_args: BaseModelArgs object containing model configuration parameters. @@ -395,6 +400,7 @@ def get_dense_model_nparams_and_flops( # 4. we follow the convention and do not account for sparsity in causal attention num_flops_per_token = ( 6 * (nparams - nparams_embedding) + # pyrefly: ignore [missing-attribute] + 6 * model_args.n_layers * model_args.n_heads * head_dims * seq_len ) @@ -410,7 +416,7 @@ def get_moe_model_nparams_and_flops( model: nn.Module, head_dims: int, seq_len: int, -) -> tuple[int, float]: +) -> tuple[int, int]: """ Calculate nparams and nflops for MoE models. 
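The `tuple[int, float]` to `tuple[int, int]` signature changes are consistent with the expression in `get_dense_model_nparams_and_flops` above being pure integer arithmetic. A small worked example of that formula, with invented (but plausible) config numbers rather than a real model flavor:

```python
def dense_flops_per_token(
    nparams: int,
    nparams_embedding: int,
    n_layers: int,
    n_heads: int,
    head_dim: int,
    seq_len: int,
) -> int:
    # 6 * non-embedding params for the matmul FLOPs per token, plus the
    # attention term that scales with sequence length, mirroring the
    # expression shown in the hunk above.
    return 6 * (nparams - nparams_embedding) + 6 * n_layers * n_heads * head_dim * seq_len


# Made-up numbers, purely to show the arithmetic stays integral:
print(dense_flops_per_token(
    nparams=8_030_000_000,
    nparams_embedding=525_000_000,
    n_layers=32,
    n_heads=32,
    head_dim=128,
    seq_len=8192,
))  # prints: 51472450944
```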
@@ -450,6 +456,7 @@ def get_moe_model_nparams_and_flops( nparams_sparse_active = ( nparams_moe_router + nparams_shared_experts + # pyrefly: ignore [missing-attribute] + nparams_experts * model_args.moe_args.top_k // model_args.moe_args.num_experts ) @@ -460,6 +467,7 @@ def get_moe_model_nparams_and_flops( num_flops_per_token = ( 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + # pyrefly: ignore [missing-attribute] + 6 * model_args.n_layers * model_args.n_heads * head_dims * seq_len ) diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index 4cb193c31a..712449f2f6 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -37,9 +37,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: pass @abstractmethod - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: pass diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index e22692bd52..7b2b3ef3ad 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -8,7 +8,7 @@ import os import re from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict from torch.distributed.checkpoint import HuggingFaceStorageReader @@ -27,6 +27,8 @@ class BaseStateDictAdapter(ABC): hf_assets_path: path to HF assets folder containing tokenizer, model weights, etc. """ + fqn_to_index_mapping: Dict[Any, int] | None + @abstractmethod def __init__( self, @@ -98,6 +100,7 @@ def __init__( if hf_safetensors_indx: self.fqn_to_index_mapping = {} for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + # pyrefly: ignore [missing-attribute] indx = re.search(r"\d+", raw_indx).group(0) self.fqn_to_index_mapping[hf_key] = int(indx) else: diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index f398dba9b5..5c2b40b217 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -69,6 +69,7 @@ def trace_handler(prof): elif torch.xpu.is_available(): gpu_device_profiled = torch.profiler.ProfilerActivity.XPU with torch.profiler.profile( + # pyrefly: ignore [bad-argument-type] activities=[ torch.profiler.ProfilerActivity.CPU, gpu_device_profiled, diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 0b1c78d0d6..d2fa409223 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -65,7 +65,7 @@ def collect(reason: str, generation: int = 1): # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, MI325X, MI355X and Intel PVC -def get_peak_flops(device_name: str) -> int: +def get_peak_flops(device_name: str) -> float: try: # Run the lspci command and capture the output result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True) diff --git a/torchtitan/train.py b/torchtitan/train.py index c897ee3c8a..8c597cd608 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -10,10 +10,11 @@ import os import time from datetime import timedelta -from typing import Any, Generator, Iterable +from typing import Any, Iterable import torch +import torch.distributed.checkpoint.stateful from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -60,7 +61,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # runtime utilities device: torch.device 
gc_handler: utils.GarbageCollection - train_context: Generator[None, None, None] + train_context: dist_utils.TrainContext gradient_accumulation_steps: int pp_has_first_stage: bool pp_has_last_stage: bool @@ -82,8 +83,10 @@ def __init__(self, job_config: JobConfig): importlib.import_module(job_config.experimental.custom_import) device_module, device_type = utils.device_module, utils.device_type + # pyrefly: ignore [read-only] self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. + # pyrefly: ignore [missing-attribute] device_module.set_device(self.device) # init distributed and build meshes @@ -99,6 +102,7 @@ def __init__(self, job_config: JobConfig): else: dp_degree, dp_rank = 1, 0 + # pyrefly: ignore [bad-argument-type] self.ft_manager = FTManager(job_config.fault_tolerance) dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) @@ -149,6 +153,7 @@ def __init__(self, job_config: JobConfig): # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) + # pyrefly: ignore [bad-argument-type] model_converters.convert(model) # metrics logging @@ -166,6 +171,7 @@ def __init__(self, job_config: JobConfig): ( model_param_count, self.metrics_processor.num_flops_per_token, + # pyrefly: ignore [bad-argument-type] ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) logger.info( @@ -242,10 +248,12 @@ def __init__(self, job_config: JobConfig): for m in self.model_parts: m.to_empty(device=init_device) with torch.no_grad(): + # pyrefly: ignore [not-callable] m.init_weights(buffer_device=buffer_device) m.train() # confirm that user will be able to view loss metrics on the console + # pyrefly: ignore [bad-argument-type] ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel @@ -253,6 +261,7 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): + # pyrefly: ignore [not-callable] model.init_weights(buffer_device=buffer_device) model.train() @@ -458,6 +467,7 @@ def post_dataloading_process( attn_type = getattr(self.model_args, "attn_type", "sdpa") if attn_type in ["flex", "varlen"]: + # pyrefly: ignore [not-callable] extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, @@ -486,6 +496,7 @@ def forward_backward_step( optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( cp_mesh=parallel_dims.world_mesh["cp"], + # pyrefly: ignore [bad-argument-type] cp_buffers=cp_buffers, cp_seq_dims=cp_seq_dims, cp_no_restore_buffers={inputs, labels}, @@ -556,6 +567,7 @@ def train_step( # If data runs out during gradient accumulation, that # entire step will not be executed. 
for _microbatch in range(self.gradient_accumulation_steps): + # pyrefly: ignore [no-matching-overload] input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) @@ -636,6 +648,7 @@ def train(self): leaf_folder=leaf_folder, ) as memory_profiler, maybe_semi_sync_training( + # pyrefly: ignore [bad-argument-type] job_config.fault_tolerance, ft_manager=self.ft_manager, model=self.model_parts[0], @@ -652,6 +665,7 @@ def train(self): ), ), ): + # pyrefly: ignore [bad-argument-type] data_iterator = self.batch_generator(self.dataloader) while self.should_continue_training(): self.step += 1 @@ -671,7 +685,9 @@ def train(self): self.job_config.validation.enable and self.validator.should_validate(self.step) ): + # pyrefly: ignore [missing-attribute] with self.loss_fn.no_rescale(): + # pyrefly: ignore [bad-argument-count] self.validator.validate(self.model_parts, self.step) # signal the profiler that the next profiling step has started
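Among the few non-annotation code changes in this patch is the rename inside `_run_experts_for_loop` (torchtitan/models/moe/moe.py): the hunk introduces `num_tokens_per_expert_list` and `x_splits` instead of rebinding `num_tokens_per_expert` (a Tensor parameter) to the list returned by `.tolist()` and `x` to the tuple returned by `torch.split`, presumably so that each name keeps a single static type. A minimal standalone sketch of the same pattern, with illustrative names and shapes that are not torchtitan's:

```python
import torch


def split_tokens_per_expert(
    x: torch.Tensor, num_tokens_per_expert: torch.Tensor
) -> list[torch.Tensor]:
    # Convert the per-expert counts to a Python list under a new name instead of
    # rebinding the tensor argument (this .tolist() call is the device-to-host
    # sync the original comment warns about).
    num_tokens_per_expert_list: list[int] = num_tokens_per_expert.tolist()
    # Any padding rows beyond the summed counts are dropped before splitting.
    x_splits = torch.split(
        x[: sum(num_tokens_per_expert_list)],
        split_size_or_sections=num_tokens_per_expert_list,
        dim=0,
    )
    return list(x_splits)


chunks = split_tokens_per_expert(torch.randn(6, 4), torch.tensor([2, 1, 3]))
print([tuple(c.shape) for c in chunks])  # [(2, 4), (1, 4), (3, 4)]
```

Keeping the tensor and its host-side list under separate names also makes the synchronization point (`.tolist()`) easier to spot when reading the loop.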