59 changes: 34 additions & 25 deletions torchtitan/components/checkpoint.py
@@ -189,15 +189,25 @@ def __init__(
self.enable = checkpoint_config.enable
self.load_only = checkpoint_config.load_only

self.states = states
self.states.update(
{
MODEL: ModelWrapper(model_parts),
OPTIMIZER: optimizers,
DATALOADER: dataloader,
LR_SCHEDULER: lr_schedulers,
}
)

self.ft_manager = (
ft_manager.manager
if ft_manager
and ft_manager.enabled
and checkpoint_config.enable_ft_dataloader_checkpoints
else None
ft_manager.manager if ft_manager and ft_manager.enabled else None
)

if ft_manager and ft_manager.enabled and not self.ft_manager:
self.enable_ft_dataloader_checkpoints = (
self.ft_manager and checkpoint_config.enable_ft_dataloader_checkpoints
)

if self.ft_manager and not self.enable_ft_dataloader_checkpoints:
logger.warn(
"Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
"This means replicas can retrain over the same data multiple times, which can result in overfitting."
@@ -229,20 +239,11 @@ def load_state_dict(state_dict):
async_mode = checkpoint_config.async_mode.lower()
self.enable_staging = (
self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager
) or self.enable_ft_dataloader_checkpoints

if not self.enable and self.ft_manager is None:
if not self.enable and not self.enable_ft_dataloader_checkpoints:
return

self.states = states
self.states.update(
{
MODEL: ModelWrapper(model_parts),
OPTIMIZER: optimizers,
DATALOADER: dataloader,
LR_SCHEDULER: lr_schedulers,
}
)
self.ft_states = {DATALOADER: dataloader}

self.staging = False
@@ -279,7 +280,7 @@ def load_state_dict(state_dict):
if (
async_mode == AsyncMode.ASYNC
or async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
or self.ft_manager
or self.enable_ft_dataloader_checkpoints
):
self.pg = dist.new_group(backend="gloo")

@@ -480,14 +481,16 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
None
"""

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
self._ft_save(curr_step)

if not self._should_save(curr_step, last_step):
return

begin = time.monotonic()
if not self.ft_manager or self.ft_manager.participating_rank() == 0:
if not self.enable_ft_dataloader_checkpoints or (
self.ft_manager and self.ft_manager.participating_rank() == 0
):
logger.info("Saving the checkpoint (or staging if async is enabled).")
checkpoint_id = self._create_checkpoint_id(curr_step)
self._async_wait()
@@ -530,7 +533,8 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
"Finished saving the checkpoint (or staging if async is enabled)"
f"in {time.monotonic() - begin:.2f} seconds."
)
elif self.ft_manager:
elif self.enable_ft_dataloader_checkpoints:
assert self.ft_manager is not None
logger.info(
"Replica %d doesn't save checkpoint.",
self.ft_manager.participating_rank(),
@@ -551,7 +555,7 @@ def load(self, step: int = -1) -> bool:
bool: Whether the checkpoint was loaded successfully.
"""

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
self._ft_load()

if not self.enable:
@@ -749,7 +753,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:

states_to_load = self._flattened_model_states_sd(states_to_load)

if self.ft_manager:
if self.enable_ft_dataloader_checkpoints:
states_to_load.pop(DATALOADER)

return states_to_load
@@ -805,7 +809,9 @@ def _async_wait(self) -> None:
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
if self.save_future is not None:
self.save_future.result()
elif self.async_mode == AsyncMode.ASYNC or self.ft_manager is not None:
elif (
self.async_mode == AsyncMode.ASYNC or self.enable_ft_dataloader_checkpoints
):
if self.save_future is not None:
self.save_future.result()
self.save_future = None
@@ -820,7 +826,10 @@ def _purge_stale_checkpoints(self):
self.keep_latest_k > 0
and dist.get_rank() == 0
and os.path.isdir(self.folder)
and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
and (
not self.enable_ft_dataloader_checkpoints
or (self.ft_manager and self.ft_manager.participating_rank() == 0)
)
):
discovered_checkpoints = []
for filename in os.listdir(self.folder):
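Note: a minimal sketch of the new gating, with hypothetical class and method names (the real logic lives in CheckpointManager.__init__ above). ft_manager stays available for other fault-tolerance features, while the derived enable_ft_dataloader_checkpoints flag alone decides whether the dataloader state takes the per-replica FT save/load path.

class _CheckpointGatingSketch:
    def __init__(self, ft_manager, enable_ft_dataloader_checkpoints: bool):
        # The manager is kept whenever fault tolerance is enabled at all ...
        self.ft_manager = (
            ft_manager.manager if ft_manager and ft_manager.enabled else None
        )
        # ... but FT dataloader checkpoints are gated by a separate, derived flag.
        self.enable_ft_dataloader_checkpoints = bool(
            self.ft_manager and enable_ft_dataloader_checkpoints
        )

    def save(self, curr_step: int) -> None:
        # Only the FT dataloader state goes through the per-replica save path;
        # the regular checkpoint path is unaffected by this flag.
        if self.enable_ft_dataloader_checkpoints:
            self._ft_save(curr_step)

    def _ft_save(self, curr_step: int) -> None:
        ...  # placeholder for the real per-replica save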
3 changes: 2 additions & 1 deletion torchtitan/distributed/utils.py
@@ -232,7 +232,7 @@ def maybe_enable_amp(


def init_distributed(
comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "", ranks: list[int] = []
):
def _warn_overwrite_env(env, val):
if env in os.environ:
@@ -276,6 +276,7 @@ def _get_distributed_backend(enable_cpu_backend):
torch.distributed.init_process_group(
backend=_get_distributed_backend(enable_cpu_backend),
timeout=timedelta(seconds=comm_config.init_timeout_seconds),
_ranks=ranks,
)


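For reference, a worked example of the replica-to-global-rank mapping that feeds the new ranks argument; the numbers are illustrative, in the PR they come from job_config.fault_tolerance (see train.py below). The list is forwarded unchanged to torch.distributed.init_process_group as the underscore-prefixed (presumably experimental) _ranks keyword.

group_size = 8                                   # ranks per replica
replica_id = 2                                   # this replica's index
first_rank = replica_id * group_size             # 16
last_rank = first_rank + group_size - 1          # 23
ranks = list(range(first_rank, last_rank + 1))   # [16, 17, ..., 23]
assert ranks == list(range(16, 24))

# init_distributed(comm_config, ..., ranks=ranks)
#   -> torch.distributed.init_process_group(..., _ranks=ranks)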
11 changes: 6 additions & 5 deletions torchtitan/experiments/forge/example_train.py
@@ -281,12 +281,13 @@ def train(self):
self.checkpointer.load(step=job_config.checkpoint.load_step)
logger.info(f"Training starts at step {self.step + 1}.")

torch_profiler = maybe_enable_profiling(
job_config.profiling,
global_step=self.step,
base_folder=job_config.job.dump_folder,
)

with (
maybe_enable_profiling(
job_config.profiling,
global_step=self.step,
base_folder=job_config.job.dump_folder,
) as torch_profiler,
maybe_enable_memory_snapshot(
job_config.profiling,
global_step=self.step,
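Since maybe_enable_profiling no longer acts as a context manager, the caller receives a started torch.profiler.profile (or None) and owns stepping and shutdown. A self-contained sketch of that caller-side pattern; _maybe_enable_profiling here is a stand-in, not torchtitan's implementation, and the schedule values are arbitrary.

import torch

def _maybe_enable_profiling(enabled: bool, global_step: int = 0):
    # Stand-in for torchtitan.tools.profiling.maybe_enable_profiling after this PR:
    # returns a started profiler, or None when profiling is disabled.
    if not enabled:
        return None
    prof = torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
    )
    prof.step_num = global_step  # keep the schedule aligned with the resumed step
    prof.start()
    return prof

torch_profiler = _maybe_enable_profiling(enabled=True)
try:
    for _ in range(8):
        torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for a train step
        if torch_profiler:
            torch_profiler.step()  # advance the wait/warmup/active schedule
finally:
    if torch_profiler:
        torch_profiler.stop()  # shutdown that the context manager previously did on exit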
13 changes: 6 additions & 7 deletions torchtitan/tools/profiling.py
@@ -18,7 +18,6 @@
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000


@contextlib.contextmanager
def maybe_enable_profiling(
profiling_config: ProfilingConfig,
*,
@@ -68,20 +67,20 @@ def trace_handler(prof):
gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
elif torch.xpu.is_available():
gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
with torch.profiler.profile(
torch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
gpu_device_profiled,
],
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
on_trace_ready=trace_handler,
record_shapes=True,
) as torch_profiler:
torch_profiler.step_num = global_step
yield torch_profiler
)
torch_profiler.step_num = global_step
torch_profiler.start()
return torch_profiler
else:
torch_profiler = contextlib.nullcontext()
yield None
return None


@contextlib.contextmanager
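Why the profiler is seeded with step_num = global_step before start(): torch.profiler.schedule returns a pure function of the step counter, so when training resumes from a checkpoint the wait/warmup/active windows stay aligned with global training steps. A small illustration (the schedule values are arbitrary, not torchtitan's configuration):

import torch

# schedule() returns a callable mapping a step index to a ProfilerAction.
sched = torch.profiler.schedule(wait=10, warmup=3, active=3)

for step in (0, 9, 10, 12, 13, 14, 15, 16):
    # 0-9: NONE (wait), 10-12: WARMUP, 13-14: RECORD,
    # 15: RECORD_AND_SAVE, then the cycle repeats from step 16.
    print(step, sched(step))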
44 changes: 36 additions & 8 deletions torchtitan/train.py
@@ -4,8 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes
import importlib
import os
import signal
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional
@@ -32,8 +34,12 @@
maybe_enable_profiling,
)

c_globals = ctypes.CDLL(None) # POSIX


class Trainer(torch.distributed.checkpoint.stateful.Stateful):
torch_profiler: torch.profiler.profile | None = None

# core configs
job_config: JobConfig
parallel_dims: ParallelDims
@@ -83,11 +89,21 @@ def __init__(self, job_config: JobConfig):
# Device has to be set before creating TorchFT manager.
device_module.set_device(self.device)

ranks = []
ft_config = job_config.fault_tolerance
if ft_config.enable:
group_size = ft_config.group_size
replica_id = ft_config.replica_id
first_rank = replica_id * group_size
last_rank = first_rank + group_size - 1
ranks = list(range(first_rank, last_rank + 1))

# init distributed and build meshes
dist_utils.init_distributed(
job_config.comm,
enable_cpu_backend=job_config.training.enable_cpu_offload,
base_folder=job_config.job.dump_folder,
ranks=ranks,
)

job_config.maybe_log()
@@ -570,13 +586,14 @@ def train(self):
if not self.ft_manager.enabled
else f"replica_{self.ft_manager.replica_id}"
)
self.torch_profiler = maybe_enable_profiling(
job_config.profiling,
global_step=self.step,
base_folder=job_config.job.dump_folder,
leaf_folder=leaf_folder,
)

with (
maybe_enable_profiling(
job_config.profiling,
global_step=self.step,
base_folder=job_config.job.dump_folder,
leaf_folder=leaf_folder,
) as torch_profiler,
maybe_enable_memory_snapshot(
job_config.profiling,
global_step=self.step,
@@ -600,6 +617,15 @@
),
),
):
if self.torch_profiler:

@ctypes.CFUNCTYPE(None, ctypes.c_int)
def sigabrt_handler(signal):
logger.info("SIGABRT received. Stopping profiler")
self.torch_profiler.export_chrome_trace("trace.json")

c_globals.signal(signal.SIGABRT, sigabrt_handler)

data_iterator = self.batch_generator(self.dataloader)
while self.should_continue_training():
self.step += 1
@@ -623,8 +649,8 @@ def train(self):
self.validator.validate(self.model_parts, self.step)

# signal the profiler that the next profiling step has started
if torch_profiler:
torch_profiler.step()
if self.torch_profiler:
self.torch_profiler.step()
if memory_profiler:
memory_profiler.step()

@@ -682,10 +708,12 @@ def close(self) -> None:
else:
trainer.train()
except Exception:
logger.info("Torchtitan training threw an exception")
if trainer:
trainer.close()
raise
else:
logger.info("Torchtitan training completed")
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed")