Skip to content

Commit 7de581c

Browse files
committed
Merge branch 'update-docs' of https://github.com/apoorvkh/torchrunx into update-docs
2 parents c58cb11 + 8ad727d commit 7de581c

File tree

5 files changed

+28
-10
lines changed

5 files changed

+28
-10
lines changed

docs/source/examples/lightning_example.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from transformers import AutoModelForCausalLM, AutoTokenizer
1111

1212
import torchrunx
13-
13+
from torchrunx.ext.lightning import TorchrunxClusterEnvironment
1414

1515
class GPT2CausalLMDataset(Dataset):
1616
def __init__(self, text_dataset):
@@ -47,7 +47,7 @@ def __init__(self):
4747
super().__init__()
4848
self.model = AutoModelForCausalLM.from_pretrained("gpt2")
4949

50-
def training_step(self, batch, batch_idx):
50+
def training_step(self, batch, *args): # pyright: ignore
5151
device_batch = {k: v.to(self.model.device) for k, v in batch.items()}
5252
loss = self.model(**device_batch).loss
5353
self.log("train_loss", loss)
@@ -72,24 +72,23 @@ def train():
7272
devices=2,
7373
num_nodes=1,
7474
strategy="ddp",
75+
plugins=[TorchrunxClusterEnvironment()],
76+
enable_checkpointing=False
7577
)
7678

7779
trainer.fit(model=lightning_model, train_dataloaders=train_loader)
80+
checkpoint = f"{trainer.log_dir}/final.ckpt"
81+
trainer.save_checkpoint(checkpoint)
7882

79-
if int(os.environ["RANK"]) == 0:
80-
return trainer.model.model
81-
return None
83+
return checkpoint
8284

8385

8486
if __name__ == "__main__":
85-
# hack to prevent lightning from recognizing SLURM environment...
86-
os.environ["SLURM_JOB_NAME"] = "bash"
87-
Path("output").mkdir(exist_ok=True)
8887
results = torchrunx.launch(
8988
func=train,
9089
hostnames=["localhost"],
9190
workers_per_host=2,
9291
)
9392

94-
trained_model: nn.Module = results.rank(0)
95-
torch.save(trained_model.state_dict(), "output/model.pth")
93+
checkpoint_path = results.rank(0)
94+
print(f"Checkpoint at: {checkpoint_path}")

src/torchrunx/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
8383
backend=launcher_payload.backend,
8484
rank=worker_global_ranks[i],
8585
local_rank=i,
86+
node_rank=agent_rank,
8687
local_world_size=num_workers,
8788
world_size=worker_world_size,
8889
hostname=launcher_payload.hostnames[agent_rank],

src/torchrunx/ext/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Extensions classes and functions."""

src/torchrunx/ext/lightning.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Pytorch Lightning extension utilities."""
2+
3+
import torch
4+
from lightning.fabric.plugins.environments.torchelastic import ( # pyright: ignore [reportMissingImports]
5+
TorchElasticEnvironment,
6+
)
7+
8+
9+
class TorchrunxClusterEnvironment(TorchElasticEnvironment):
10+
"""PyTorch Lightning ClusterEnvironment compatible with torchrunx."""
11+
12+
@staticmethod
13+
def detect() -> bool:
14+
"""Returns ``True`` if the current process was launched using torchrunx."""
15+
return torch.distributed.is_available()

src/torchrunx/worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class WorkerArgs:
3232
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
3333
rank: int
3434
local_rank: int
35+
node_rank: int
3536
local_world_size: int
3637
world_size: int
3738
hostname: str
@@ -79,6 +80,7 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc
7980

8081
os.environ["RANK"] = str(worker_args.rank)
8182
os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
83+
os.environ["GROUP_RANK"] = str(worker_args.node_rank)
8284
os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
8385
os.environ["WORLD_SIZE"] = str(worker_args.world_size)
8486
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname

0 commit comments

Comments
 (0)