Skip to content

Commit 8b6dab2

Browse files
authored
Merge pull request #36 from apoorvkh/submitit
submitit fix
2 parents 0b1f587 + 6c774a1 commit 8b6dab2

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

examples/submitit_train.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import copy
2+
3+
import submitit
4+
import torch
5+
from torch.utils.data import Dataset
6+
from transformers import BertForMaskedLM, Trainer, TrainingArguments
7+
8+
import torchrunx as trx
9+
10+
11+
class DummyDataset(Dataset):
    """Synthetic dataset of random token ids for smoke-testing a training loop.

    Every sample is a dict with ``input_ids`` and ``labels``. The labels hold
    the same values as the inputs but live in their own tensor storage, so
    mutating one never affects the other.

    Args:
        max_text_length: Number of token ids per sample.
        num_samples: Number of samples in the dataset.
        vocab_size: Exclusive upper bound for the sampled token ids. The
            default (30522) is the bound this example has always used.
    """

    def __init__(
        self,
        max_text_length: int = 16,
        num_samples: int = 20000,
        vocab_size: int = 30522,
    ) -> None:
        super().__init__()
        # Token ids are drawn uniformly from [0, vocab_size).
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_text_length))
        # ``clone`` is the tensor-native way to take an independent copy.
        self.labels = self.input_ids.clone()

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.input_ids)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return the ``input_ids``/``labels`` pair stored at ``index``."""
        return {
            "input_ids": self.input_ids[index],
            "labels": self.labels[index],
        }
25+
26+
def main():
    """Train ``bert-base-uncased`` on ``DummyDataset`` for 100 steps.

    Loads the masked-LM model, builds the synthetic dataset, and runs a
    Hugging Face ``Trainer`` that writes to the ``output`` directory.
    """
    bert = BertForMaskedLM.from_pretrained("bert-base-uncased")
    random_tokens = DummyDataset()

    # Short run: 100 optimizer steps with a per-device batch of 16.
    hf_args = TrainingArguments(
        output_dir="output",
        do_train=True,
        per_device_train_batch_size=16,
        max_steps=100,
    )

    Trainer(
        model=bert,  # type: ignore
        args=hf_args,
        train_dataset=random_tokens,
    ).train()
46+
47+
def launch():
    """Run ``main`` through ``torchrunx.launch`` using SLURM-derived settings.

    Host names and the per-host worker count come from the ``trx.slurm_*``
    helpers (presumably read from the current SLURM allocation; confirm
    against the torchrunx docs).
    """
    hosts = trx.slurm_hosts()
    workers = trx.slurm_workers()
    trx.launch(
        func=main,
        func_kwargs={},
        hostnames=hosts,
        workers_per_host=workers,
    )
54+
55+
if __name__ == "__main__":
    # Resource request for the job that will execute ``launch``.
    # NOTE(review): ``time`` looks like minutes per submitit's SLURM executor;
    # confirm. The constraint and partition values are cluster-specific.
    slurm_options = {
        "time": 60,
        "nodes": 1,
        "ntasks_per_node": 1,
        "mem": "32G",
        "cpus_per_task": 4,
        "gpus_per_node": 2,
        "constraint": "geforce3090",
        "partition": "3090-gcondo",
        "stderr_to_stdout": True,
        "use_srun": False,
    }

    # submitit working files are placed under ./logs.
    executor = submitit.SlurmExecutor(folder="logs")
    executor.update_parameters(**slurm_options)
    executor.submit(launch)

src/torchrunx/launcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ def launch(
191191
ssh_config_file=ssh_config_file,
192192
)
193193
raise
194-
#
194+
finally:
195+
print_process.kill()
195196

196-
print_process.terminate() # TODO: or close?
197197
return_values: dict[int, Any] = dict(ChainMap(*[s.return_values for s in agent_statuses]))
198198
return return_values

0 commit comments

Comments
 (0)