Commit 6dba84b

lmcafee-nvidia authored and jaredcasper committed
Retro updates.
1 parent cd2537d commit 6dba84b

36 files changed (+1282 -2653 lines)

.gitignore (+2)

@@ -4,3 +4,5 @@ build
 .coverage_*
 *.egg-info
 *~
+slurm*
+logs

megatron/core/enums.py (+3 -1)

@@ -1,7 +1,9 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import enum
 
 class ModelType(enum.Enum):
     encoder_or_decoder = 1
     encoder_and_decoder = 2
+    retro_encoder = 3
+    retro_decoder = 4
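
The two new ModelType members give Retro models their own identity at the core level. A minimal sketch of how calling code might branch on them (the pick_model_type helper and its flags are illustrative, not part of this commit):

from megatron.core.enums import ModelType

def pick_model_type(use_retro: bool, encoder_only: bool) -> ModelType:
    # Hypothetical helper: return a Retro-specific member when retrieval
    # is enabled, otherwise fall back to the plain decoder setting.
    if use_retro:
        return ModelType.retro_encoder if encoder_only else ModelType.retro_decoder
    return ModelType.encoder_or_decoder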

megatron/data/gpt_dataset.py (+72 -72)

@@ -308,84 +308,84 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     shuffle_idx_filename = _filename + '_shuffle_idx.npy'
 
     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
-        if (not os.path.isfile(doc_idx_filename)) or \
-           (not os.path.isfile(sample_idx_filename)) or \
-           (not os.path.isfile(shuffle_idx_filename)):
+    if torch.distributed.get_rank() == 0 and \
+       (not os.path.isfile(doc_idx_filename) or
+        not os.path.isfile(sample_idx_filename) or
+        not os.path.isfile(shuffle_idx_filename)):
 
-            print_rank_0(' > WARNING: could not find index map files, building '
-                         'the indices on rank 0 ...')
+        print_rank_0(' > WARNING: could not find index map files, building '
+                     'the indices on rank 0 ...')
 
-            # For the last epoch, decide whether include the entire epoch
-            # in the global shuffle or not.
+        # For the last epoch, decide whether include the entire epoch
+        # in the global shuffle or not.
 
-            # If we need only one epoch, then separating last epoch does
-            # not mean anything.
-            if num_epochs == 1:
-                separate_last_epoch = False
-                print(' > only one epoch required, setting '
-                      'separate_last_epoch to False', flush=True)
+        # If we need only one epoch, then separating last epoch does
+        # not mean anything.
+        if num_epochs == 1:
+            separate_last_epoch = False
+            print(' > only one epoch required, setting '
+                  'separate_last_epoch to False', flush=True)
 
-            else:
-                # Get the number of samples for the last epoch
-                num_samples_from_epochs_minus_one = (
-                    (num_epochs - 1) * tokens_per_epoch - 1) // seq_length
-                last_epoch_num_samples = num_samples - \
-                                         num_samples_from_epochs_minus_one
-                assert last_epoch_num_samples >= 0, \
-                    'last epoch number of samples should be non-negative.'
-                num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
-                assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
-                    'last epoch number of samples exceeded max value.'
-                # If we have less than 80% of the samples for the last epoch,
-                # seperate out the epoch and treat it differently.
-                # Note: the 80% number is just based on common sense and can
-                # be adjusted if needed.
-                separate_last_epoch = (last_epoch_num_samples <
-                                       int(0.80 * num_samples_per_epoch))
-                if separate_last_epoch:
-                    string = ' > last epoch number of samples ({}) is smaller '\
-                             'than 80% of number of samples per epoch ({}), '\
-                             'setting separate_last_epoch to True'
-                else:
-                    string = ' > last epoch number of samples ({}) is larger '\
-                             'than 80% of number of samples per epoch ({}), '\
-                             'setting separate_last_epoch to False'
-                print(string.format(last_epoch_num_samples,
-                                    num_samples_per_epoch), flush=True)
-
-            # doc-idx.
-            start_time = time.time()
-            doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
-                                     separate_last_epoch)
-            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
-            print_rank_0(' > elasped time to build and save doc-idx mapping '
-                         '(seconds): {:4f}'.format(time.time() - start_time))
-            # sample-idx.
-            start_time = time.time()
-            # Use C++ implementation for speed.
-            # First compile and then import.
-            from megatron.data import helpers
-            assert doc_idx.dtype == np.int32
-            assert sizes.dtype == np.int32
-            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
-                                                  num_epochs, tokens_per_epoch)
-            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
-            print_rank_0(' > elasped time to build and save sample-idx mapping '
-                         '(seconds): {:4f}'.format(time.time() - start_time))
-            # shuffle-idx.
-            start_time = time.time()
-            # -1 is due to data structure used to retieve the index:
-            # sample i --> [sample_idx[i], sample_idx[i+1])
-            if separate_last_epoch:
-                num_samples_ = num_samples_from_epochs_minus_one
-            else:
-                num_samples_ = sample_idx.shape[0] - 1
-            shuffle_idx = _build_shuffle_idx(num_samples_,
-                                             sample_idx.shape[0] - 1, np_rng)
-            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
-            print_rank_0(' > elasped time to build and save shuffle-idx mapping'
-                         ' (seconds): {:4f}'.format(time.time() - start_time))
+        else:
+            # Get the number of samples for the last epoch
+            num_samples_from_epochs_minus_one = (
+                (num_epochs - 1) * tokens_per_epoch - 1) // seq_length
+            last_epoch_num_samples = num_samples - \
+                                     num_samples_from_epochs_minus_one
+            assert last_epoch_num_samples >= 0, \
+                'last epoch number of samples should be non-negative.'
+            num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
+            assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
+                'last epoch number of samples exceeded max value.'
+            # If we have less than 80% of the samples for the last epoch,
+            # seperate out the epoch and treat it differently.
+            # Note: the 80% number is just based on common sense and can
+            # be adjusted if needed.
+            separate_last_epoch = (last_epoch_num_samples <
+                                   int(0.80 * num_samples_per_epoch))
+            if separate_last_epoch:
+                string = ' > last epoch number of samples ({}) is smaller '\
+                         'than 80% of number of samples per epoch ({}), '\
+                         'setting separate_last_epoch to True'
+            else:
+                string = ' > last epoch number of samples ({}) is larger '\
+                         'than 80% of number of samples per epoch ({}), '\
+                         'setting separate_last_epoch to False'
+            print(string.format(last_epoch_num_samples,
+                                num_samples_per_epoch), flush=True)
+
+        # doc-idx.
+        start_time = time.time()
+        doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
+                                 separate_last_epoch)
+        np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+        print_rank_0(' > elasped time to build and save doc-idx mapping '
+                     '(seconds): {:4f}'.format(time.time() - start_time))
+        # sample-idx.
+        start_time = time.time()
+        # Use C++ implementation for speed.
+        # First compile and then import.
+        from megatron.data import helpers
+        assert doc_idx.dtype == np.int32
+        assert sizes.dtype == np.int32
+        sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
+                                              num_epochs, tokens_per_epoch)
+        np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+        print_rank_0(' > elasped time to build and save sample-idx mapping '
+                     '(seconds): {:4f}'.format(time.time() - start_time))
+        # shuffle-idx.
+        start_time = time.time()
+        # -1 is due to data structure used to retieve the index:
+        # sample i --> [sample_idx[i], sample_idx[i+1])
+        if separate_last_epoch:
+            num_samples_ = num_samples_from_epochs_minus_one
+        else:
+            num_samples_ = sample_idx.shape[0] - 1
+        shuffle_idx = _build_shuffle_idx(num_samples_,
+                                         sample_idx.shape[0] - 1, np_rng)
+        np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+        print_rank_0(' > elasped time to build and save shuffle-idx mapping'
+                     ' (seconds): {:4f}'.format(time.time() - start_time))
 
     # This should be a barrier but nccl barrier assumes
     # device_index=rank which is not the case for model
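
The net effect of this hunk is that the rank-0 check and the missing-file check are folded into one guard, so the whole index-building body moves out one indentation level. A condensed sketch of the resulting control flow (build_all_indices is a hypothetical stand-in for the doc-idx/sample-idx/shuffle-idx code above):

import os
import torch

def maybe_build_indices(doc_idx_filename, sample_idx_filename,
                        shuffle_idx_filename, build_all_indices):
    # Only rank 0 builds the cached index files, and only when at least
    # one of them is missing -- a single condition instead of two nested ifs.
    if torch.distributed.get_rank() == 0 and \
       (not os.path.isfile(doc_idx_filename) or
        not os.path.isfile(sample_idx_filename) or
        not os.path.isfile(shuffle_idx_filename)):
        build_all_indices()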

megatron/data/indexed_dataset.py (+1 -1)

@@ -95,7 +95,7 @@ def write_longs(f, a):
     3: np.int16,
     4: np.int32,
     5: np.int64,
-    6: np.float,
+    6: np.float32,
     7: np.double,
     8: np.uint16
 }
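
For context: np.float was only an alias for Python's 64-bit float, was deprecated in NumPy 1.20 and removed in 1.24, so dtype code 6 now names an explicit 32-bit type. A quick check of what that implies for element size (the code_to_dtype name is illustrative, not from the file):

import numpy as np

# np.float resolved to Python's float (64-bit); np.float32 is 4 bytes wide.
code_to_dtype = {6: np.float32, 7: np.double}   # mirrors the mapping above
assert np.dtype(code_to_dtype[6]).itemsize == 4
assert np.dtype(code_to_dtype[7]).itemsize == 8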

megatron/model/enums.py (+4 -1)

@@ -1,10 +1,13 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 import enum
 
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
+    retro_encoder = 3
+    retro_decoder = 4
+    retro_decoder_with_retriever = 5
 
 class AttnType(enum.Enum):
     self_attn = 1
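
One plausible use of the new LayerType members, not shown in this diff, is tagging which decoder layers carry retrieval cross-attention. A hedged sketch under that assumption (assign_layer_types and retro_layer_numbers are hypothetical names):

from megatron.model.enums import LayerType

def assign_layer_types(num_layers, retro_layer_numbers):
    # Hypothetical: layers listed in retro_layer_numbers get a Retro
    # decoder variant; the first of them also hosts the retriever.
    layer_types = []
    for n in range(1, num_layers + 1):
        if n not in retro_layer_numbers:
            layer_types.append(LayerType.decoder)
        elif n == retro_layer_numbers[0]:
            layer_types.append(LayerType.retro_decoder_with_retriever)
        else:
            layer_types.append(LayerType.retro_decoder)
    return layer_types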

megatron/model/gpt_model.py (+7 -5)

@@ -1,4 +1,4 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 """GPT-2 model."""
 
@@ -77,16 +77,18 @@ def set_input_tensor(self, input_tensor):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
                 labels=None, tokentype_ids=None, inference_params=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            ret_input_ids=ret_input_ids,
-            ret_position_ids=ret_position_ids,
-            ret_attn_mask=ret_attn_mask,
+            retriever_input_ids=retriever_input_ids,
+            retriever_position_ids=retriever_position_ids,
+            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)
 
         if self.post_process:
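
Any call site that used the old ret_* keywords has to adopt the new names. A self-contained sketch of the rename (TinyStub only mirrors the updated signature; it is not the real GPTModel):

class TinyStub:
    # Mirrors the renamed keyword arguments of the updated forward().
    def forward(self, input_ids, position_ids, attention_mask,
                retriever_input_ids=None, retriever_position_ids=None,
                retriever_attn_mask=None, labels=None, tokentype_ids=None,
                inference_params=None):
        return retriever_input_ids is not None

model = TinyStub()
# was: ret_input_ids / ret_position_ids / ret_attn_mask
uses_retrieval = model.forward(
    "input_ids", "position_ids", "attention_mask",
    retriever_input_ids="neighbor_tokens",
    retriever_position_ids="neighbor_positions",
    retriever_attn_mask="neighbor_mask")
print(uses_retrieval)  # True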

megatron/model/language_model.py (+31 -53)

@@ -1,4 +1,4 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 """Transformer based language model."""
 
@@ -7,10 +7,10 @@
 
 from megatron import get_args
 from megatron.core import mpu, tensor_parallel
+from megatron.core.enums import ModelType
 
-from .enums import LayerType, AttnMaskType
+from .enums import AttnMaskType, LayerType
 from .module import MegatronModule
-from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
 from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
 from .transformer import ParallelTransformer
 from .utils import get_linear_layer
@@ -352,6 +352,7 @@ def __init__(self,
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.add_retriever = args.retro_add_retriever
         self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
 
         # Embeddings.
@@ -380,39 +381,18 @@ def __init__(self,
             # https://github.com/kingoflolz/mesh-transformer-jax/
             self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
 
-        # Retriever (bi-directional transformer with cross attention)
-        if args.retro_add_retriever:
-            self.retriever = ParallelRetroEncoder(
-                self.init_method,
-                output_layer_init_method,
-                self_attn_mask_type=AttnMaskType.padding,
-                pre_process=self.pre_process,
-                post_process=False,
-            )
-            self._retriever_key = 'retriever'
-        else:
-            self.retriever = None
-
-        # Encoder (usually set to True, False if part of an encoder-decoder
-        # architecture and in encoder-only stage).
-        if self.add_encoder:
-            if args.retro_add_retriever:
-                self.encoder = ParallelRetroTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                    retriever=self.retriever,
-                )
-            else:
-                self.encoder = ParallelTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                )
+        # Encoder (usually set to True, False if part of an encoder-decoder
+        # architecture and in encoder-only stage).
+        if self.add_encoder:
+            self.encoder = ParallelTransformer(
+                self.init_method,
+                output_layer_init_method,
+                model_type=args.model_type if not args.retro_add_retriever \
+                    else ModelType.retro_decoder,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process,
+            )
             self._encoder_key = 'encoder'
         else:
             self.encoder = None
@@ -423,6 +403,7 @@ def __init__(self,
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
+                model_type=args.model_type,
                 layer_type=LayerType.decoder,
                 self_attn_mask_type=self.decoder_attn_mask_type,
                 pre_process=self.pre_process,
@@ -477,26 +458,29 @@ def set_input_tensor(self, input_tensor):
 
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
-        # Retriever embedding.
-        if self.retriever and self.pre_process:
-            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
-                                             tokentype_ids=tokentype_ids)
-        else:
-            retriever_input = None
-
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                            tokentype_ids=tokentype_ids)
         else:
            encoder_input = None
 
+        # Retriever embedding.
+        if self.add_retriever and self.pre_process:
+            retriever_input = self.embedding(retriever_input_ids,
+                                             retriever_position_ids,
+                                             tokentype_ids=tokentype_ids)
+        else:
+            retriever_input = None
+
         # Rotary positional embeddings
         rotary_pos_emb = None
         if self.use_rotary_position_embeddings:
@@ -509,19 +493,13 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
-                if self.retriever:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        retriever_output=retriever_input,
-                        retriever_attn_mask=ret_attn_mask,
-                        inference_params=inference_params)
-                else:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        inference_params=inference_params,
-                        rotary_pos_emb=rotary_pos_emb)
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    retriever_input=retriever_input,
+                    retriever_attn_mask=retriever_attn_mask,
+                    inference_params=inference_params,
+                    rotary_pos_emb=rotary_pos_emb)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
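
With ParallelRetroEncoder and ParallelRetroTransformer gone, there is a single encoder call path, and Retro behaviour is selected by the arguments alone: retriever_input is simply None when args.retro_add_retriever is off. A toy illustration of that dispatch pattern, not the real ParallelTransformer:

class UnifiedEncoder:
    # Toy stand-in: one forward path that ignores the retriever arguments
    # when they are None, mirroring how the unified call above behaves.
    def __call__(self, hidden_states, attn_mask, retriever_input=None,
                 retriever_attn_mask=None, inference_params=None,
                 rotary_pos_emb=None):
        if retriever_input is None:
            return ("gpt", hidden_states)
        return ("retro", hidden_states, retriever_input)

encoder = UnifiedEncoder()
print(encoder("tokens", "mask"))                                  # plain GPT path
print(encoder("tokens", "mask", retriever_input="neighbors",
              retriever_attn_mask="neighbor_mask"))               # Retro path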
