
Commit c1e0689

Checkpoint a set number of individual Transformer layers

- consider the case of pipeline-model parallelism
- clean up arguments
- argument naming cleanup
- update readme and examples

1 parent 68797d9 commit c1e0689

12 files changed: +76 -31 lines

README.md

Lines changed: 6 additions & 6 deletions
@@ -156,7 +156,7 @@ OUTPUT_ARGS="--log-interval 10 \
              --save-interval 500 \
              --eval-interval 100 \
              --eval-iters 10 \
-             --checkpoint-activations"
+             --activations-checkpoint-method uniform"

 python pretrain_bert.py \
        $BERT_ARGS \
@@ -345,7 +345,7 @@ python pretrain_ict.py \
        --max-position-embeddings 256 \
        --ict-head-size 128 \
        --train-iters 100000 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --bert-load /path/to/pretrained_bert \
        --load checkpoints \
        --save checkpoints \
@@ -375,7 +375,7 @@ python tools/create_doc_index.py \
        --ict-head-size 128 \
        --num-attention-heads 12 \
        --batch-size 128 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --seq-length 256 \
        --max-position-embeddings 256 \
        --ict-load /path/to/pretrained_ict \
@@ -482,7 +482,7 @@ python tasks/main.py \
        --merge-file $MERGE_FILE \
        --load $CHECKPOINT_PATH \
        --micro-batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 10 \
        --no-load-optim \
        --no-load-rng
@@ -512,7 +512,7 @@ python tasks/main.py \
        --merge-file $MERGE_FILE \
        --load $CHECKPOINT_PATH \
        --micro-batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 10 \
        --no-load-optim \
        --no-load-rng
@@ -542,7 +542,7 @@ COMMON_TASK_ARGS="--num-layers 24 \
 COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
                       --valid-data $VALID_DATA \
                       --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
-                      --checkpoint-activations \
+                      --activations-checkpoint-method uniform \
                       --save-interval 10000 \
                       --save $CHECKPOINT_PATH \
                       --log-interval 100 \

examples/evaluate_retriever_nq.sh

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ python tasks/main.py \
        --num-attention-heads 12 \
        --tensor-model-parallel-size 1 \
        --micro-batch-size 128 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --load ${CHECKPOINT_PATH} \

examples/evaluate_zeroshot_gpt.sh

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --log-interval 10 \

examples/finetune_mnli_distributed.sh

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --micro-batch-size 8 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --lr 5.0e-5 \
        --lr-decay-style linear \
        --lr-warmup-fraction 0.065 \

examples/finetune_race_distributed.sh

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --hidden-size 1024 \
        --num-attention-heads 16 \
        --micro-batch-size 4 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --lr 1.0e-5 \
        --lr-decay-style linear \
        --lr-warmup-fraction 0.06 \

examples/pretrain_gpt.sh

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ python pretrain_gpt.py \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction .01 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \

examples/pretrain_gpt3_175B.sh

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ options=" \
        --init-method-std 0.006 \
        --tensorboard-dir <TENSORBOARD DIRECTORY> \
        --fp16 \
-       --checkpoint-activations "
+       --activations-checkpoint-method uniform "


 run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"

examples/pretrain_gpt_distributed.sh

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction .01 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \

examples/pretrain_gpt_distributed_with_mp.sh

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
        --lr-warmup-fraction .01 \
-       --checkpoint-activations \
+       --activations-checkpoint-method uniform \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \

megatron/arguments.py

Lines changed: 21 additions & 4 deletions
@@ -91,6 +91,12 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
+    if args.checkpoint_activations:
+        print('--checkpoint-activations is no longer valid, '
+              'use --activations-checkpoint-method instead. '
+              'Defaulting to activations-checkpoint-method=uniform.')
+        args.activations_checkpoint_method = 'uniform'
+    del args.checkpoint_activations

     # Set input defaults.
     for key in defaults:
@@ -234,9 +240,9 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'
     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
+        assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'

     _print_args(args)
     return args
@@ -402,8 +408,19 @@ def _add_training_args(parser):
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) block: checkpoint the input activation of only a set '
+                       'number of individual Transformer layers and skip the rest, '
+                       'default) checkpoint the inputs of every Transformer layer')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '
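
To make the semantics of the two methods concrete, here is a minimal sketch (a hypothetical helper, not part of this commit) that lists which layer-index ranges end up checkpointed for a given per-stage layer count and a given --activations-checkpoint-num-layers value:

# Hypothetical illustration of the two --activations-checkpoint-method modes.
# Not Megatron-LM code; it only mirrors the selection logic described in the help text.
def checkpointed_layer_ranges(num_layers, method, num_layers_per_chunk):
    if method == 'uniform':
        # Checkpoint the input activation of each chunk of num_layers_per_chunk layers.
        return [(l, l + num_layers_per_chunk)
                for l in range(0, num_layers, num_layers_per_chunk)]
    if method == 'block':
        # Checkpoint only the first num_layers_per_chunk individual layers; skip the rest.
        return [(l, l + 1) for l in range(num_layers_per_chunk)]
    raise ValueError("Invalid activation checkpoint method.")

# With 8 layers on a pipeline stage and --activations-checkpoint-num-layers 2:
print(checkpointed_layer_ranges(8, 'uniform', 2))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
print(checkpointed_layer_ranges(8, 'block', 2))    # [(0, 1), (1, 2)]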

megatron/model/transformer.py

Lines changed: 29 additions & 10 deletions
@@ -542,8 +542,8 @@ def __init__(self, init_method, output_layer_init_method,
         self.input_tensor = None

         # Store activation checkpointing flags.
-        self.checkpoint_activations = args.checkpoint_activations
-        self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.activations_checkpoint_method = args.activations_checkpoint_method
+        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers

         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
@@ -609,12 +609,31 @@ def custom_forward(*inputs):

         # Make sure memory is freed.
         mpu.reset_checkpointed_activations_memory_buffer()
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
+
+        if self.activations_checkpoint_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and
+            # checkpoint the input activation of each divided chunk.
+            # A method to further reduce memory usage by reducing the number
+            # of checkpoints.
+            l = 0
+            while l < self.num_layers:
+                hidden_states = mpu.checkpoint(
+                    custom(l, l + self.activations_checkpoint_num_layers),
+                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                l += self.activations_checkpoint_num_layers
+        elif self.activations_checkpoint_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # A method to fully use the device memory by removing redundant
+            # re-computation.
+            for l in range(self.num_layers):
+                if l < self.activations_checkpoint_num_layers:
+                    hidden_states = mpu.checkpoint(
+                        custom(l, l + 1),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
+                    hidden_states = custom(l, l + 1)(
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+        else:
+            raise ValueError("Invalid activation checkpoint method.")

         return hidden_states

@@ -637,7 +656,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
             'for not None values in layer_past, ' \
             'expected get_key_value to be set'
         if get_key_value:
-            assert not self.checkpoint_activations, \
+            assert self.activations_checkpoint_method is None, \
                 'get_key_value does not work with ' \
                 'activation checkpointing'

@@ -656,7 +675,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
         if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
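
For readers who want to experiment with the dispatch above outside Megatron-LM, the following is a self-contained toy sketch (an assumed example, not the committed code): torch.utils.checkpoint.checkpoint stands in for mpu.checkpoint, and a small Linear stack stands in for the Transformer layers, while the control flow mirrors the uniform and block branches.

# Toy sketch of the uniform vs. block activation-checkpointing dispatch.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyStack(nn.Module):
    def __init__(self, num_layers=8, hidden=64, method='uniform', num_ckpt_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.method = method
        self.num_ckpt_layers = num_ckpt_layers

    def _run(self, start, end):
        # Mirrors the custom(start, end) closure: run layers [start, end).
        def custom_forward(x):
            for layer in self.layers[start:end]:
                x = torch.relu(layer(x))
            return x
        return custom_forward

    def forward(self, x):
        if self.method == 'uniform':
            # Checkpoint the input of each chunk of num_ckpt_layers layers.
            l = 0
            while l < len(self.layers):
                x = checkpoint(self._run(l, l + self.num_ckpt_layers), x)
                l += self.num_ckpt_layers
        elif self.method == 'block':
            # Checkpoint only the first num_ckpt_layers layers; run the rest normally.
            for l in range(len(self.layers)):
                if l < self.num_ckpt_layers:
                    x = checkpoint(self._run(l, l + 1), x)
                else:
                    x = self._run(l, l + 1)(x)
        else:
            raise ValueError("Invalid activation checkpoint method.")
        return x

out = ToyStack(method='block')(torch.randn(4, 64, requires_grad=True))
out.sum().backward()  # re-computation happens only for the checkpointed layers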

megatron/mpu/random.py

Lines changed: 12 additions & 3 deletions
@@ -47,9 +47,18 @@ def init_checkpointed_activations_memory_buffer():

     per_layer = args.micro_batch_size * args.max_position_embeddings * \
         args.hidden_size // args.tensor_model_parallel_size
-    assert args.num_layers % args.checkpoint_num_layers == 0, \
-        'number of layers is not divisible by checkpoint-num-layers'
-    num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
+    num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
+    if args.virtual_pipeline_model_parallel_size is not None:
+        num_layers = num_layers // args.virtual_pipeline_model_parallel_size
+
+    if args.activations_checkpoint_method == 'uniform':
+        assert num_layers % args.activations_checkpoint_num_layers == 0, \
+            'total number of layers is not divisible by activations-checkpoint-num-layers'
+        num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
+    elif args.activations_checkpoint_method == 'block':
+        assert args.activations_checkpoint_num_layers <= num_layers, \
+            'total number of layers is fewer than the number of layers to checkpoint'
+        num_checkpointer_layers = args.activations_checkpoint_num_layers
     numel = per_layer * num_checkpointer_layers
     dtype = torch.half
     if not args.fp16:
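
As a sanity check on the buffer sizing above, here is a small worked example with made-up hyperparameters (illustrative numbers only). It follows the arithmetic in the diff, including the fact that the uniform branch divides the global args.num_layers rather than the per-stage count.

# Worked example of the checkpointed-activations buffer sizing (hypothetical values).
micro_batch_size = 4
max_position_embeddings = 1024
hidden_size = 2048
tensor_model_parallel_size = 2
pipeline_model_parallel_world_size = 4
num_layers_total = 32
activations_checkpoint_num_layers = 2

per_layer = micro_batch_size * max_position_embeddings * hidden_size \
    // tensor_model_parallel_size                                     # 4,194,304 elements
num_layers = num_layers_total // pipeline_model_parallel_world_size   # 8 layers per stage

# uniform: one checkpoint per chunk; as written in the diff, the global layer count is divided.
uniform_numel = per_layer * (num_layers_total // activations_checkpoint_num_layers)
print(uniform_numel)   # 67,108,864 half-precision elements (~128 MiB)

# block: only the first N layers per pipeline stage are checkpointed.
block_numel = per_layer * activations_checkpoint_num_layers
print(block_numel)     # 8,388,608 half-precision elements (~16 MiB)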
