torchtitan/config/job_config.py (19 additions, 1 deletion)

@@ -538,7 +538,7 @@ class Checkpoint:

 @dataclass
 class ActivationCheckpoint:
-    mode: Literal["selective", "full", "none"] = "selective"
+    mode: Literal["selective", "full", "memory_budget", "none"] = "selective"
     """Type of activation checkpointing to use"""

     selective_ac_option: str = "2"
@@ -567,6 +567,24 @@ class ActivationCheckpoint:
     rematerialized.
     """

+    memory_budget: float = 0.5
+    """
+    When mode is set to "memory_budget", this value determines how much the
+    partitioner in the compiler should trade off compute for memory.
+    0.0 corresponds to the activation memory from applying activation
+    checkpointing to the full compiled region, and 1.0 corresponds to the
+    activation memory of the default runtime-optimized strategy. Read more here:
+    https://pytorch.org/blog/activation-checkpointing-techniques/
+    """
+
+    visualize_memory_budget_pareto: bool = False
+    """
+    If set, dumps an SVG visualization of the expected runtime vs. activation
+    memory tradeoff for every memory budget value from 0 to 1 in increments
+    of 0.05 into the {--job.dump_folder}/memory_budget_pareto folder. See an
+    example here: https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
+    """
+

 @dataclass
 class Compile:
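The two new fields slot into the existing ActivationCheckpoint config, so in a job .toml they would typically live under the activation_checkpoint section (or be passed as --activation_checkpoint.* CLI overrides, matching the {--job.dump_folder} notation above), with compile enabled since the new mode only applies to compiled regions. A minimal sketch of exercising the fields directly in Python, for example from a test; the specific values are illustrative, not recommendations:

    from torchtitan.config.job_config import ActivationCheckpoint

    ac_config = ActivationCheckpoint(
        mode="memory_budget",                 # new compiler-driven mode; requires torch.compile on the model
        memory_budget=0.5,                    # 0.0 = full-AC activation memory, 1.0 = default runtime-optimized memory
        visualize_memory_budget_pareto=True,  # dump pareto SVGs under {job.dump_folder}/memory_budget_pareto
    )
    # ac_config would then be passed to apply_ac() below, which asserts that compile is enabled.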
torchtitan/distributed/activation_checkpoint.py (24 additions, 10 deletions)

@@ -7,6 +7,7 @@
 # This file provides the util functions to apply activation checkpointing to the model.
 # Technically, this is not a part of distributed, but distributed module is the best place to put it.

+import os
 from collections import defaultdict

 import torch
@@ -279,6 +280,7 @@ def apply_ac(
     model_compile_enabled: bool = False,
     use_flex_attn: bool = False,
     op_sac_save_list: set[torch._ops.OpOverload] | None = None,
+    base_folder: str = "",
 ) -> None:
     """Apply activation checkpointing to the model.

@@ -297,15 +299,27 @@
         None
     """

-    for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = _apply_ac_to_transformer_block(
-            transformer_block,
-            ac_config,
-            base_fqn=f"layers.{layer_id}",
-            model_compile_enabled=model_compile_enabled,
-            use_flex_attn=use_flex_attn,
-            op_sac_save_list=op_sac_save_list,
-        )
-        model.layers.register_module(layer_id, transformer_block)
+    if ac_config.mode == "memory_budget":
+        assert model_compile_enabled, "Memory budget mode requires model to be compiled"
+        if ac_config.visualize_memory_budget_pareto:
+            pareto_dir = os.path.join(base_folder, "memory_budget_pareto")
+            if not os.path.exists(pareto_dir):
+                os.makedirs(pareto_dir, exist_ok=True)
+            torch._functorch.config.memory_budget_pareto_dir = pareto_dir
+            torch._functorch.config.visualize_memory_budget_pareto = True
+
+        torch._functorch.config.activation_memory_budget = ac_config.memory_budget
+        logger.info(f"Selected {ac_config.memory_budget} budget option")
+    else:
+        for layer_id, transformer_block in model.layers.named_children():
+            transformer_block = _apply_ac_to_transformer_block(
+                transformer_block,
+                ac_config,
+                base_fqn=f"layers.{layer_id}",
+                model_compile_enabled=model_compile_enabled,
+                use_flex_attn=use_flex_attn,
+                op_sac_save_list=op_sac_save_list,
+            )
+            model.layers.register_module(layer_id, transformer_block)

     logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
torchtitan/experiments/llama4/infra/parallelize.py (1 addition)

@@ -120,6 +120,7 @@ def parallelize_llama(
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
+            base_folder=job_config.job.dump_folder,
         )

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
torchtitan/experiments/qwen3/infra/parallelize.py (1 addition)

@@ -114,6 +114,7 @@ def parallelize_qwen3(
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
+            base_folder=job_config.job.dump_folder,
         )

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
torchtitan/experiments/simple_fsdp/llama3/parallelize.py (1 addition)

@@ -85,6 +85,7 @@ def parallelize_llama(
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
+            base_folder=job_config.job.dump_folder,
         )

     # apply data parallel
torchtitan/models/deepseek_v3/infra/parallelize.py (1 addition)

@@ -113,6 +113,7 @@ def parallelize_deepseekv3(
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
+            base_folder=job_config.job.dump_folder,
         )

     if model_compile_enabled:
torchtitan/models/llama3/infra/parallelize.py (1 addition)

@@ -102,6 +102,7 @@ def parallelize_llama(
             model_compile_enabled=model_compile_enabled,
             use_flex_attn=use_flex_attn,
             op_sac_save_list=_op_sac_save_list,
+            base_folder=job_config.job.dump_folder,
         )

     # turn on per-TransformerBlock compile after AC wrapping and before FSDP