import copy

import submitit
import torch
from torch.utils.data import Dataset
from transformers import BertForMaskedLM, Trainer, TrainingArguments

import torchrunx as trx

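
# Synthetic dataset: random token IDs drawn from bert-base-uncased's
# 30522-token vocabulary; the labels are a copy of the inputs, so this only
# exercises the training loop rather than a real masked-LM objective.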
class DummyDataset(Dataset):
    def __init__(self, max_text_length=16, num_samples=20000) -> None:
        super().__init__()
        self.input_ids = torch.randint(0, 30522, (num_samples, max_text_length))
        self.labels = copy.deepcopy(self.input_ids)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, index):
        return {
            "input_ids": self.input_ids[index],
            "labels": self.labels[index],
        }

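
# `main` runs once in every worker process spawned by torchrunx; the
# HuggingFace Trainer picks up the distributed environment that torchrunx
# sets up, so no manual DDP initialization is needed here.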
def main():
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    train_dataset = DummyDataset()

    ## Training

    training_arguments = TrainingArguments(
        output_dir="output",
        do_train=True,
        per_device_train_batch_size=16,
        max_steps=100,
    )

    trainer = Trainer(
        model=model,  # type: ignore
        args=training_arguments,
        train_dataset=train_dataset,
    )

    trainer.train()

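
# Entry point for the SLURM job: fan `main` out across all hosts in the
# allocation, with the number of workers per host taken from the SLURM
# allocation.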
def launch():
    trx.launch(
        func=main,
        func_kwargs={},
        hostnames=trx.slurm_hosts(),
        workers_per_host=trx.slurm_workers(),
    )

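
# Submit `launch` as a SLURM batch job from the login node via submitit.
# The resource parameters below (partition, constraint, etc.) are
# cluster-specific and should be adjusted for your own SLURM setup.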
if __name__ == "__main__":
    executor = submitit.SlurmExecutor(folder="logs")

    executor.update_parameters(
        time=60,
        nodes=1,
        ntasks_per_node=1,
        mem="32G",
        cpus_per_task=4,
        gpus_per_node=2,
        constraint="geforce3090",
        partition="3090-gcondo",
        stderr_to_stdout=True,
        use_srun=False,
    )

    executor.submit(launch)