10 changes: 10 additions & 0 deletions torchtitan/config/job_config.py
@@ -239,6 +239,8 @@ class Training:
deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

deterministic_warn_only: bool = False
"""Only warn about ops that lack a deterministic implementation instead of erroring"""

@dataclass
class Parallelism:
@@ -556,6 +558,14 @@ class ActivationCheckpoint:
rematerialized.
"""

preserve_rng_state: bool = False
"""Set to True if deterministic output compared to the non-checkpointed pass is required. Stashes and restores the RNG state during each checkpoint, which may be slower."""

determinism_check: str = "default"
"""A string specifying the determinism function. """

debug: bool = False
"""Capture activation checkpointing debug information. Will be slower."""

@dataclass
class Compile:
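The new fields slot into the existing Training and ActivationCheckpoint dataclasses, so they can be set like any other option. A minimal sketch of the resulting knobs, constructing the dataclasses directly for illustration (in a real run they are populated from the job config / CLI overrides, and the remaining fields are assumed to keep their defaults):

from torchtitan.config.job_config import ActivationCheckpoint, Training

# Deterministic algorithms, but only warn (instead of raising) on ops that
# have no deterministic implementation.
training = Training(deterministic=True, deterministic_warn_only=True)

# Reproducible activation checkpointing: stash/restore RNG state around each
# checkpointed region and keep the default recompute determinism check.
ac = ActivationCheckpoint(
    preserve_rng_state=True,
    determinism_check="default",
    debug=False,
)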
9 changes: 5 additions & 4 deletions torchtitan/distributed/activation_checkpoint.py
@@ -37,7 +37,7 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
ac_freq = int(ac_config.selective_ac_option)
if not ac_freq or _layer_sac_count % ac_freq == 0:
return ptd_checkpoint_wrapper(
module, preserve_rng_state=False, early_stop=ac_config.early_stop
module,
preserve_rng_state=ac_config.preserve_rng_state,
determinism_check=ac_config.determinism_check,
early_stop=ac_config.early_stop,
debug=ac_config.debug,
)
else:
return module
@@ -124,7 +124,9 @@ def selective_checkpointing_context_fn():
return ptd_checkpoint_wrapper(
module,
context_fn=selective_checkpointing_context_fn,
preserve_rng_state=False,
preserve_rng_state=ac_config.preserve_rng_state,
determinism_check=ac_config.determinism_check,
debug=ac_config.debug,
early_stop=ac_config.early_stop,
)

@@ -140,8 +142,7 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
nn.Module: The module with full activation checkpointing applied.
"""
return ptd_checkpoint_wrapper(
module, preserve_rng_state=False, early_stop=ac_config.early_stop
)
module,
preserve_rng_state=ac_config.preserve_rng_state,
determinism_check=ac_config.determinism_check,
early_stop=ac_config.early_stop,
debug=ac_config.debug,
)


def _apply_op_sac_to_transformer_block_with_flex(
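All three wrapper call sites now forward the same kwargs, which ptd_checkpoint_wrapper hands down to the checkpointing machinery. A rough eager-mode sketch of what those kwargs mean, assuming they behave as in non-reentrant torch.utils.checkpoint (the block function and shapes are made up for illustration):

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Any recomputable segment of the forward pass.
    return torch.relu(x @ x.t())

x = torch.randn(8, 8, requires_grad=True)
out = checkpoint(
    block,
    x,
    use_reentrant=False,
    preserve_rng_state=True,      # replay the same RNG state on recompute so dropout etc. match
    determinism_check="default",  # compare recomputed tensor metadata against the first pass; "none" disables
    debug=False,                  # True captures extra information to report on a mismatch
)
out.sum().backward()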
3 changes: 2 additions & 1 deletion torchtitan/distributed/utils.py
@@ -85,6 +85,7 @@ def set_determinism(
device: torch.device,
seed: int | None = None,
deterministic: bool = False,
deterministic_warn_only: bool = False,
distinct_seed_mesh_dim: str = "pp",
) -> None:
"""
@@ -99,7 +100,7 @@
"""
if deterministic:
logger.info("Deterministic algorithm enabled (expect perf degradation).")
torch.use_deterministic_algorithms(True)
torch.use_deterministic_algorithms(True, warn_only=deterministic_warn_only)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# env var for deterministic CuBLAS
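The new flag is forwarded straight to torch.use_deterministic_algorithms. A minimal standalone sketch of the behavior it enables, mirroring set_determinism (the kthvalue call is an illustrative assumption of an op without a deterministic CUDA implementation, not something torchtitan itself invokes):

import os
import torch

os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # required for deterministic cuBLAS
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    # With warn_only=True this emits a UserWarning and runs nondeterministically;
    # with warn_only=False it would raise a RuntimeError and abort the step.
    torch.kthvalue(x, k=10)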
1 change: 1 addition & 0 deletions torchtitan/train.py
@@ -128,6 +128,7 @@ def __init__(self, job_config: JobConfig):
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.training.deterministic_warn_only,
)
self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
