From 5c562d16f9fb20444a47f93442077c0adb9eb6e0 Mon Sep 17 00:00:00 2001
From: "Goswami, Subrata"
Date: Thu, 25 Sep 2025 20:04:16 -0700
Subject: [PATCH] Add config options for deterministic training:
 use_deterministic_algorithms() warn_only, AC preserve_rng_state,
 AC determinism_check, AC debug

---
 torchtitan/config/job_config.py                 | 10 ++++++++++
 torchtitan/distributed/activation_checkpoint.py |  9 +++++----
 torchtitan/distributed/utils.py                 |  3 ++-
 torchtitan/train.py                             |  1 +
 4 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 0eb06a0d4..885b390f1 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -239,6 +239,8 @@ class Training:
     deterministic: bool = False
     """Use deterministic algorithms wherever possible, may be slower"""
+    deterministic_warn_only: bool = False
+    """Only warn about ops without deterministic implementations instead of erroring"""
 
 
 @dataclass
 class Parallelism:
@@ -556,6 +558,14 @@ class ActivationCheckpoint:
     rematerialized.
     """
+    preserve_rng_state: bool = False
+    """Set to true if deterministic output compared to non-checkpointed passes is required; stashes and restores the RNG state during each checkpoint, may be slower"""
+
+    determinism_check: str = "default"
+    """A string specifying the determinism check passed to the checkpoint wrapper"""
+
+    debug: bool = False
+    """Capture activation checkpointing debug information; will be slower"""
 
 
 @dataclass
 class Compile:
diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index 227c2ca21..7cc1daa56 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -37,7 +37,7 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
     ac_freq = int(ac_config.selective_ac_option)
     if not ac_freq or _layer_sac_count % ac_freq == 0:
         return ptd_checkpoint_wrapper(
-            module, preserve_rng_state=False, early_stop=ac_config.early_stop
+            module, preserve_rng_state=ac_config.preserve_rng_state, determinism_check=ac_config.determinism_check, early_stop=ac_config.early_stop, debug=ac_config.debug
         )
     else:
         return module
@@ -124,7 +124,9 @@ def selective_checkpointing_context_fn():
 
     return ptd_checkpoint_wrapper(
         module,
         context_fn=selective_checkpointing_context_fn,
-        preserve_rng_state=False,
+        preserve_rng_state=ac_config.preserve_rng_state,
+        determinism_check=ac_config.determinism_check,
+        debug=ac_config.debug,
         early_stop=ac_config.early_stop,
     )
@@ -140,8 +142,7 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
         nn.Module: The module with full activation checkpointing applied.
     """
     return ptd_checkpoint_wrapper(
-        module, preserve_rng_state=False, early_stop=ac_config.early_stop
-    )
+        module, preserve_rng_state=ac_config.preserve_rng_state, determinism_check=ac_config.determinism_check, early_stop=ac_config.early_stop, debug=ac_config.debug)
 
 
 def _apply_op_sac_to_transformer_block_with_flex(
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 159b6229d..930183cfc 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -85,6 +85,7 @@ def set_determinism(
     device: torch.device,
     seed: int | None = None,
     deterministic: bool = False,
+    deterministic_warn_only: bool = False,
     distinct_seed_mesh_dim: str = "pp",
 ) -> None:
     """
@@ -99,7 +100,7 @@
     """
     if deterministic:
         logger.info("Deterministic algorithm enabled (expect perf degradation).")
-        torch.use_deterministic_algorithms(True)
+        torch.use_deterministic_algorithms(True, warn_only=deterministic_warn_only)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
         # env var for deterministic CuBLAS
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0afcac8dc..d1198a7d0 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -128,6 +128,7 @@ def __init__(self, job_config: JobConfig):
             self.device,
             job_config.training.seed,
             job_config.training.deterministic,
+            job_config.training.deterministic_warn_only,
         )
 
         self.train_spec = train_spec_module.get_train_spec(job_config.model.name)