From 7e4dc10f7d7856ac7ec1175f1462e069f56c435b Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 21 Feb 2024 11:43:46 -0800 Subject: [PATCH 01/22] Adding support for FP8 training --- open_lm/main.py | 6 +- open_lm/model.py | 143 +++++++++++++++++++++++++++++++++------------- open_lm/norms.py | 40 ++++++++++--- open_lm/params.py | 11 ++++ open_lm/train.py | 59 ++++++++++++++----- 5 files changed, 198 insertions(+), 61 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index 6e3dc7a5..22c1165a 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -63,7 +63,6 @@ terminate_sync_process, ) - LATEST_CHECKPOINT_NAME = "epoch_latest.pt" @@ -439,6 +438,8 @@ def main(args): random_seed(args.seed, args.rank) + all_gpus = dist.new_group(backend='nccl') + if args.distributed: if args.fsdp: transformer_layer_cls = None @@ -498,12 +499,14 @@ def main(args): random_seed(args.seed, rank=0) model = FSDP( model, + process_group=all_gpus, auto_wrap_policy=transformer_auto_wrapper_policy, device_id=device, mixed_precision=mp_policy, cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), use_orig_params=args.fsdp_use_orig_params, limit_all_gathers=args.fsdp_limit_all_gathers, + sync_module_states=True, **fsdp_kwargs, ) @@ -754,6 +757,7 @@ def main(args): total_steps=total_steps, args=args, tb_writer=writer, + all_gpus=all_gpus ) if args.distributed: diff --git a/open_lm/model.py b/open_lm/model.py index 7484ca5d..d33a4482 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -35,6 +35,17 @@ except ImportError: MambaLMHeadModel = None +# Adding flag if using TE FP8 +using_te = False +try: + import transformer_engine.pytorch as te + from transformer_engine.common import recipe + fp8_format = recipe.Format.HYBRID + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + using_te = True +except ImportError as ie: + using_te = False + # from openclip _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs @@ -117,41 +128,72 @@ def __init__(self, layer_id, args: Params): super().__init__() self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads - self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) - self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + if using_te: + self.in_proj = te.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device='cuda') + self.out_proj = te.Linear(args.n_heads * self.head_dim, args.dim, bias=False, device='cuda') + else: + self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func self.apply_qk_norm = args.apply_qk_norm - # initialize norm layers for queries and keys if needed - self.q_norm = ( - args.norm_type( - args.n_heads * self.head_dim, - eps=args.norm_eps, + if using_te: + # initialize norm layers for queries and keys if needed + self.q_norm = ( + te.LayerNorm( + args.n_heads * self.head_dim, + eps=args.norm_eps, + ) + if self.apply_qk_norm + else nn.Identity() ) - if self.apply_qk_norm - else nn.Identity() - ) - self.k_norm = ( - args.norm_type( - args.n_heads * self.head_dim, - eps=args.norm_eps, + self.k_norm = ( + te.LayerNorm( + args.n_heads * self.head_dim, + eps=args.norm_eps, + ) + if self.apply_qk_norm + else nn.Identity() + ) + else: + # initialize norm layers for 
queries and keys if needed + self.q_norm = ( + args.norm_type( + args.n_heads * self.head_dim, + eps=args.norm_eps, + ) + if self.apply_qk_norm + else nn.Identity() + ) + self.k_norm = ( + args.norm_type( + args.n_heads * self.head_dim, + eps=args.norm_eps, + ) + if self.apply_qk_norm + else nn.Identity() ) - if self.apply_qk_norm - else nn.Identity() - ) self.layer_id = layer_id self.dim = args.dim self.reset_parameters() def reset_parameters(self): - # initialize weights by trunc_normal(1/sqrt(fan_in)) - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std) - # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better. - std = std / math.sqrt(2 * (self.layer_id + 1)) - torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) + if using_te: + # initialize weights by trunc_normal(1/sqrt(fan_in)) + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.in_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) + # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better. + std = std / math.sqrt(2 * (self.layer_id + 1)) + torch.nn.init.trunc_normal_(self.out_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) + else: + # initialize weights by trunc_normal(1/sqrt(fan_in)) + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std) + # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better. + std = std / math.sqrt(2 * (self.layer_id + 1)) + torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False): batchsize, q_len, _ = x.shape @@ -202,9 +244,14 @@ def __init__(self, layer_id, args: Params): elif args.ffn_type == "gelu": # Follows mosaic mpt7b, but without a bias. 
self.hidden_dim = args.dim * 4 - self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False) - self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False) - self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) + if using_te: + self._ff_w1 = te.Linear(args.dim, self.hidden_dim, bias=False, device='cuda') + self._ff_w2 = te.Linear(self.hidden_dim, self.dim, bias=False, device='cuda') + self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) + else: + self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False) + self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False) + self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) elif args.ffn_type == "moe": moe_args = MoEArgs( hidden_size=args.dim, @@ -222,14 +269,24 @@ def __init__(self, layer_id, args: Params): self.feed_forward = MoE(moe_args) self.layer_id = layer_id - self.attention_norm = args.norm_type( - args.dim, - eps=args.norm_eps, - ) - self.ffn_norm = args.norm_type( - args.dim, - eps=args.norm_eps, - ) + if using_te: + self.attention_norm = te.LayerNorm( + args.dim, + eps=args.norm_eps, + ) + self.ffn_norm = te.LayerNorm( + args.dim, + eps=args.norm_eps, + ) + else: + self.attention_norm = args.norm_type( + args.dim, + eps=args.norm_eps, + ) + self.ffn_norm = args.norm_type( + args.dim, + eps=args.norm_eps, + ) self.attention.seq_len = args.seq_len self.reset_parameters() @@ -243,12 +300,20 @@ def reset_parameters(self): std = std / math.sqrt(2 * (self.layer_id + 1)) torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std) elif self._ffn_type == "gelu": - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std) + if using_te: + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self._ff_w1.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) - std = 1.0 / math.sqrt(self.hidden_dim) - std = std / math.sqrt(2 * (self._layer_id + 1)) - torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) + std = 1.0 / math.sqrt(self.hidden_dim) + std = std / math.sqrt(2 * (self._layer_id + 1)) + torch.nn.init.trunc_normal_(self._ff_w2.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) + else: + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std) + + std = 1.0 / math.sqrt(self.hidden_dim) + std = std / math.sqrt(2 * (self._layer_id + 1)) + torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x, past_key_value=None, use_cache=False): h, past_key_value = self.attention( diff --git a/open_lm/norms.py b/open_lm/norms.py index f02f2e48..41fa4a6f 100644 --- a/open_lm/norms.py +++ b/open_lm/norms.py @@ -8,6 +8,16 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +# Adding flag if using TE FP8 +using_te = False +try: + import transformer_engine.pytorch as te + from transformer_engine.common import recipe + fp8_format = recipe.Format.HYBRID + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + using_te = True +except ImportError as ie: + using_te = False class LayerNorm(nn.Module): # NOTE: taken from official pytorch implementation and modified @@ -55,7 +65,14 @@ def reset_parameters(self) -> None: self.bias.zero_() def forward(self, input: Tensor) -> Tensor: - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, 
self.eps) + if using_te: + layer_norm_module = te.LayerNorm(self.normalized_shape, eps=self.eps, device='cuda', params_dtype=input.dtype) + output_tensor = layer_norm_module(input) + if self.weight is not None and self.bias is not None: + output_tensor = output_tensor * self.weight + self.bias + return output_tensor + else: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) def extra_repr(self) -> str: return ( @@ -77,13 +94,20 @@ def forward(self, x): downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias with torch.autocast(enabled=False, device_type=module_device.type): - return F.layer_norm( - downcast_x, - self.normalized_shape, - downcast_weight, - downcast_bias, - self.eps, - ) + if using_te: + layer_norm_module = te.LayerNorm(self.normalized_shape, eps=self.eps, device='cuda', params_dtype=downcast_x.dtype) + output_tensor = layer_norm_module(downcast_x) + if downcast_weight is not None and downcast_bias is not None: + output_tensor = output_tensor * downcast_weight + downcast_bias + return output_tensor + else: + return F.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) def _cast_if_autocast_enabled(tensor): diff --git a/open_lm/params.py b/open_lm/params.py index 9d5efa7c..1d06e283 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -742,6 +742,17 @@ def parse_args(args): action="store_true", help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.", ) + parser.add_argument( + "--use-smp-flash-attention", + type=int, + default=None, + help="Using SMP Flash Attention.", + ) + parser.add_argument( + "--sharding-strategy", + default=None, + help="Sharding Strategy", + ) add_model_args(parser) diff --git a/open_lm/train.py b/open_lm/train.py index 4ff5a27e..0ac6f2a7 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -26,6 +26,16 @@ from open_lm.precision import get_autocast from open_lm.meters import AverageMeter +# Adding flag if using TE FP8 +using_te = False +try: + import transformer_engine.pytorch as te + from transformer_engine.common import recipe + fp8_format = recipe.Format.HYBRID + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + using_te = True +except ImportError as ie: + using_te = False def unwrap_model(model): if hasattr(model, "module"): @@ -41,7 +51,7 @@ def backward(total_loss, scaler): total_loss.backward() -def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None): +def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, all_gpus=None): """Trains model for one epoch on the provided data. 
Returns: @@ -113,19 +123,36 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler optimizer.zero_grad() if args.accum_freq == 1: - with autocast(): - inputs, targets = sample_chunk(texts, args) - out, _, _ = model(inputs) + if using_te: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + inputs, targets = sample_chunk(texts, args) + + out, _, _ = model(inputs) + + if args.log_logit_mean: + logit_m.update(torch.mean(out).item()) + + total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) + total_loss = total_lm_loss + if args.moe_freq > 0: + total_load_balancing_loss = batched_load_balancing_loss(moe_args) + clear_load_balancing_loss() + total_loss += total_load_balancing_loss + else: + with autocast(): + inputs, targets = sample_chunk(texts, args) - if args.log_logit_mean: - logit_m.update(torch.mean(out).item()) + out, _, _ = model(inputs) - total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) - total_loss = total_lm_loss - if args.moe_freq > 0: - total_load_balancing_loss = batched_load_balancing_loss(moe_args) - clear_load_balancing_loss() - total_loss += total_load_balancing_loss + if args.log_logit_mean: + logit_m.update(torch.mean(out).item()) + + total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) + total_loss = total_lm_loss + if args.moe_freq > 0: + total_load_balancing_loss = batched_load_balancing_loss(moe_args) + clear_load_balancing_loss() + total_loss += total_load_balancing_loss backward(total_loss, scaler) else: @@ -147,7 +174,13 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if inputs_ii.shape[0] == 0: break targets_ii = targets[ii * per_batch : (ii + 1) * per_batch] - out, _, _ = model(inputs_ii) + + if using_te: + ## TODO + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + out, _, _ = model(inputs_ii) + else: + out, _, _ = model(inputs_ii) if args.log_logit_mean: logit_m.update(torch.mean(out).item()) From e8cad2ad76ae7305b406f9bf36080a4e799f8487 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 21 Feb 2024 12:10:56 -0800 Subject: [PATCH 02/22] Linter changes --- open_lm/main.py | 4 ++-- open_lm/model.py | 11 ++++++----- open_lm/norms.py | 10 ++++++++-- open_lm/train.py | 10 +++++++--- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index 22c1165a..52cb69b4 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -438,7 +438,7 @@ def main(args): random_seed(args.seed, args.rank) - all_gpus = dist.new_group(backend='nccl') + all_gpus = dist.new_group(backend="nccl") if args.distributed: if args.fsdp: @@ -757,7 +757,7 @@ def main(args): total_steps=total_steps, args=args, tb_writer=writer, - all_gpus=all_gpus + all_gpus=all_gpus, ) if args.distributed: diff --git a/open_lm/model.py b/open_lm/model.py index d33a4482..1c6d6b29 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -40,6 +40,7 @@ try: import transformer_engine.pytorch as te from transformer_engine.common import recipe + fp8_format = recipe.Format.HYBRID fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") using_te = True @@ -129,8 +130,8 @@ def __init__(self, layer_id, args: Params): self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads if using_te: - self.in_proj = te.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device='cuda') - self.out_proj = te.Linear(args.n_heads * self.head_dim, 
args.dim, bias=False, device='cuda') + self.in_proj = te.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device="cuda") + self.out_proj = te.Linear(args.n_heads * self.head_dim, args.dim, bias=False, device="cuda") else: self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) @@ -139,7 +140,7 @@ def __init__(self, layer_id, args: Params): self.apply_qk_norm = args.apply_qk_norm if using_te: - # initialize norm layers for queries and keys if needed + # initialize norm layers for queries and keys if needed self.q_norm = ( te.LayerNorm( args.n_heads * self.head_dim, @@ -245,8 +246,8 @@ def __init__(self, layer_id, args: Params): # Follows mosaic mpt7b, but without a bias. self.hidden_dim = args.dim * 4 if using_te: - self._ff_w1 = te.Linear(args.dim, self.hidden_dim, bias=False, device='cuda') - self._ff_w2 = te.Linear(self.hidden_dim, self.dim, bias=False, device='cuda') + self._ff_w1 = te.Linear(args.dim, self.hidden_dim, bias=False, device="cuda") + self._ff_w2 = te.Linear(self.hidden_dim, self.dim, bias=False, device="cuda") self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) else: self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False) diff --git a/open_lm/norms.py b/open_lm/norms.py index 41fa4a6f..6847ab5b 100644 --- a/open_lm/norms.py +++ b/open_lm/norms.py @@ -13,12 +13,14 @@ try: import transformer_engine.pytorch as te from transformer_engine.common import recipe + fp8_format = recipe.Format.HYBRID fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False + class LayerNorm(nn.Module): # NOTE: taken from official pytorch implementation and modified # to allow revoval of gain and bias independently @@ -66,7 +68,9 @@ def reset_parameters(self) -> None: def forward(self, input: Tensor) -> Tensor: if using_te: - layer_norm_module = te.LayerNorm(self.normalized_shape, eps=self.eps, device='cuda', params_dtype=input.dtype) + layer_norm_module = te.LayerNorm( + self.normalized_shape, eps=self.eps, device="cuda", params_dtype=input.dtype + ) output_tensor = layer_norm_module(input) if self.weight is not None and self.bias is not None: output_tensor = output_tensor * self.weight + self.bias @@ -95,7 +99,9 @@ def forward(self, x): downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias with torch.autocast(enabled=False, device_type=module_device.type): if using_te: - layer_norm_module = te.LayerNorm(self.normalized_shape, eps=self.eps, device='cuda', params_dtype=downcast_x.dtype) + layer_norm_module = te.LayerNorm( + self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype + ) output_tensor = layer_norm_module(downcast_x) if downcast_weight is not None and downcast_bias is not None: output_tensor = output_tensor * downcast_weight + downcast_bias diff --git a/open_lm/train.py b/open_lm/train.py index 0ac6f2a7..d9ae777b 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -31,12 +31,14 @@ try: import transformer_engine.pytorch as te from transformer_engine.common import recipe + fp8_format = recipe.Format.HYBRID fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False + def unwrap_model(model): if hasattr(model, "module"): return model.module @@ -51,7 +53,9 @@ def 
backward(total_loss, scaler): total_loss.backward() -def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, all_gpus=None): +def train_one_epoch( + model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, all_gpus=None +): """Trains model for one epoch on the provided data. Returns: @@ -126,7 +130,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if using_te: with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): inputs, targets = sample_chunk(texts, args) - + out, _, _ = model(inputs) if args.log_logit_mean: @@ -137,7 +141,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler if args.moe_freq > 0: total_load_balancing_loss = batched_load_balancing_loss(moe_args) clear_load_balancing_loss() - total_loss += total_load_balancing_loss + total_loss += total_load_balancing_loss else: with autocast(): inputs, targets = sample_chunk(texts, args) From 05140874815ae94adcdf88930652c62a2755e8a1 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Thu, 29 Feb 2024 15:59:43 -0800 Subject: [PATCH 03/22] Converting all Linears to TE Linears except output Linear --- open_lm/model.py | 157 +++++++++++++++++++--------------------------- open_lm/params.py | 11 ---- 2 files changed, 63 insertions(+), 105 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 1c6d6b29..aaa82c76 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -98,7 +98,7 @@ class Params: seq_len: int = 2048 post_embed_norm: bool = False weight_tying: bool = False - norm_type: nn.Module = nn.LayerNorm + norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn apply_qk_norm: bool = False moe_loss_weight: float = 0.1 @@ -129,72 +129,41 @@ def __init__(self, layer_id, args: Params): super().__init__() self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads - if using_te: - self.in_proj = te.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device="cuda") - self.out_proj = te.Linear(args.n_heads * self.head_dim, args.dim, bias=False, device="cuda") - else: - self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) - self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) + self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func self.apply_qk_norm = args.apply_qk_norm - if using_te: - # initialize norm layers for queries and keys if needed - self.q_norm = ( - te.LayerNorm( - args.n_heads * self.head_dim, - eps=args.norm_eps, - ) - if self.apply_qk_norm - else nn.Identity() - ) - self.k_norm = ( - te.LayerNorm( - args.n_heads * self.head_dim, - eps=args.norm_eps, - ) - if self.apply_qk_norm - else nn.Identity() - ) - else: - # initialize norm layers for queries and keys if needed - self.q_norm = ( - args.norm_type( - args.n_heads * self.head_dim, - eps=args.norm_eps, - ) - if self.apply_qk_norm - else nn.Identity() + # initialize norm layers for queries and keys if needed + self.q_norm = ( + args.norm_type( + args.n_heads * self.head_dim, + eps=args.norm_eps, ) - self.k_norm = ( - args.norm_type( - args.n_heads * self.head_dim, - eps=args.norm_eps, - ) - if self.apply_qk_norm - else nn.Identity() + if 
self.apply_qk_norm + else nn.Identity() + ) + self.k_norm = ( + args.norm_type( + args.n_heads * self.head_dim, + eps=args.norm_eps, ) + if self.apply_qk_norm + else nn.Identity() + ) self.layer_id = layer_id self.dim = args.dim self.reset_parameters() def reset_parameters(self): - if using_te: - # initialize weights by trunc_normal(1/sqrt(fan_in)) - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.in_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) - # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better. - std = std / math.sqrt(2 * (self.layer_id + 1)) - torch.nn.init.trunc_normal_(self.out_proj.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) - else: - # initialize weights by trunc_normal(1/sqrt(fan_in)) - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std) - # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better. - std = std / math.sqrt(2 * (self.layer_id + 1)) - torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) + # initialize weights by trunc_normal(1/sqrt(fan_in)) + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std) + # scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better. + std = std / math.sqrt(2 * (self.layer_id + 1)) + torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False): batchsize, q_len, _ = x.shape @@ -245,14 +214,9 @@ def __init__(self, layer_id, args: Params): elif args.ffn_type == "gelu": # Follows mosaic mpt7b, but without a bias. self.hidden_dim = args.dim * 4 - if using_te: - self._ff_w1 = te.Linear(args.dim, self.hidden_dim, bias=False, device="cuda") - self._ff_w2 = te.Linear(self.hidden_dim, self.dim, bias=False, device="cuda") - self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) - else: - self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False) - self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False) - self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) + self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False) + self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False) + self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) elif args.ffn_type == "moe": moe_args = MoEArgs( hidden_size=args.dim, @@ -270,24 +234,14 @@ def __init__(self, layer_id, args: Params): self.feed_forward = MoE(moe_args) self.layer_id = layer_id - if using_te: - self.attention_norm = te.LayerNorm( - args.dim, - eps=args.norm_eps, - ) - self.ffn_norm = te.LayerNorm( - args.dim, - eps=args.norm_eps, - ) - else: - self.attention_norm = args.norm_type( - args.dim, - eps=args.norm_eps, - ) - self.ffn_norm = args.norm_type( - args.dim, - eps=args.norm_eps, - ) + self.attention_norm = args.norm_type( + args.dim, + eps=args.norm_eps, + ) + self.ffn_norm = args.norm_type( + args.dim, + eps=args.norm_eps, + ) self.attention.seq_len = args.seq_len self.reset_parameters() @@ -301,20 +255,12 @@ def reset_parameters(self): std = std / math.sqrt(2 * (self.layer_id + 1)) torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std) elif self._ffn_type == "gelu": - if using_te: - std = 1.0 / math.sqrt(self.dim) - 
torch.nn.init.trunc_normal_(self._ff_w1.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) - - std = 1.0 / math.sqrt(self.hidden_dim) - std = std / math.sqrt(2 * (self._layer_id + 1)) - torch.nn.init.trunc_normal_(self._ff_w2.weight_tensor.float(), std=std, a=-3 * std, b=3 * std) - else: - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std) + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std) - std = 1.0 / math.sqrt(self.hidden_dim) - std = std / math.sqrt(2 * (self._layer_id + 1)) - torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) + std = 1.0 / math.sqrt(self.hidden_dim) + std = std / math.sqrt(2 * (self._layer_id + 1)) + torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x, past_key_value=None, use_cache=False): h, past_key_value = self.attention( @@ -487,10 +433,33 @@ def forward(self, x): return out, None, None +def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=['output'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + nn_linear_to_te_linear(module, include_modules, exclude_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in include_modules and name not in exclude_modules: + old_module = model._modules[name] + model._modules[name] = te.Linear( + module.in_features, + module.out_features, + module.bias is not None, + device = 'cuda' + ) + if copy_weights: + model._modules[name].weight_tensor.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None and old_module.bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + return model + def create_model(args): if "mamba" in args.model: model = Mamba(create_params(args)) + if using_te: + model = nn_linear_to_te_linear(model) return model else: model = Transformer(create_params(args)) + if using_te: + model = nn_linear_to_te_linear(model) return model diff --git a/open_lm/params.py b/open_lm/params.py index 1d06e283..9d5efa7c 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -742,17 +742,6 @@ def parse_args(args): action="store_true", help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.", ) - parser.add_argument( - "--use-smp-flash-attention", - type=int, - default=None, - help="Using SMP Flash Attention.", - ) - parser.add_argument( - "--sharding-strategy", - default=None, - help="Sharding Strategy", - ) add_model_args(parser) From ff8e8c88eb66c34a0297da89bbc913520452a3e7 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Thu, 29 Feb 2024 16:02:36 -0800 Subject: [PATCH 04/22] Fix linter errors --- open_lm/model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index aaa82c76..4bd0f410 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -433,7 +433,7 @@ def forward(self, x): return out, None, None -def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=['output'], copy_weights=True): +def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=["output"], copy_weights=True): for name, module in model.named_children(): if len(list(module.children())) > 0: nn_linear_to_te_linear(module, include_modules, exclude_modules, copy_weights) @@ -441,10 +441,7 @@ def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=['output'] if 
isinstance(module, torch.nn.Linear) and name in include_modules and name not in exclude_modules: old_module = model._modules[name] model._modules[name] = te.Linear( - module.in_features, - module.out_features, - module.bias is not None, - device = 'cuda' + module.in_features, module.out_features, module.bias is not None, device="cuda" ) if copy_weights: model._modules[name].weight_tensor.data.copy_(old_module.weight.data) @@ -452,6 +449,7 @@ def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=['output'] model._modules[name].bias.data.copy_(old_module.bias) return model + def create_model(args): if "mamba" in args.model: model = Mamba(create_params(args)) From 9224b0eff17c42761ba21f6d9928f36b6ce7cdc0 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Thu, 11 Apr 2024 23:59:23 -0700 Subject: [PATCH 05/22] Rebase from main and update FP8 changes --- open_lm/attention.py | 75 +++++++++++++++++++++++++-------- open_lm/main.py | 2 + open_lm/model.py | 8 ++-- open_lm/norms.py | 16 ++++--- open_lm/params.py | 10 ++++- open_lm/train.py | 8 ++-- sagemaker_train/cfg_sample.yaml | 10 ++--- 7 files changed, 91 insertions(+), 38 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index f134786c..a4960adc 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -4,6 +4,17 @@ from torch.nn import functional as F import xformers.ops as xops +# Adding flag if using TE FP8 +using_te = False +try: + import transformer_engine.pytorch as te + from transformer_engine.common import recipe + + fp8_format = recipe.Format.HYBRID + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + using_te = True +except ImportError as ie: + using_te = False def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype): """ @@ -81,27 +92,57 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None): # Same as above, we would like to use: # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device) mask = get_rectangular_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype) - return ( - F.scaled_dot_product_attention( - queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask + if using_te: + scaleddotproductattn_module = te.DotProductAttention(queries.size(-1), 1) + return ( + scaleddotproductattn_module( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask + ) + .transpose(1, 2) + .contiguous() + ) + else: + return ( + F.scaled_dot_product_attention( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask + ) + .transpose(1, 2) + .contiguous() ) - .transpose(1, 2) - .contiguous() - ) elif queries.shape[1] == 1: - return ( - F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)) - .transpose(1, 2) - .contiguous() - ) + if using_te: + scaleddotproductattn_module = te.DotProductAttention(queries.size(-1), 1) + return ( + scaleddotproductattn_module( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask + ) + .transpose(1, 2) + .contiguous() + ) + else: + return ( + F.scaled_dot_product_attention(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)) + .transpose(1, 2) + .contiguous() + ) else: - return ( - F.scaled_dot_product_attention( - queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal + if using_te: 
+ scaleddotproductattn_module = te.DotProductAttention(queries.size(-1), 1) + return ( + scaleddotproductattn_module( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attn_mask=mask + ) + .transpose(1, 2) + .contiguous() + ) + else: + return ( + F.scaled_dot_product_attention( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), is_causal=is_causal + ) + .transpose(1, 2) + .contiguous() ) - .transpose(1, 2) - .contiguous() - ) ATTN_ACTIVATIONS = { diff --git a/open_lm/main.py b/open_lm/main.py index fb9b5ab0..fe3558da 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -463,6 +463,8 @@ def main(args): if args.grad_checkpointing: model.set_grad_checkpointing() + + all_gpus = dist.new_group(backend="nccl") if args.distributed: if args.fsdp: diff --git a/open_lm/model.py b/open_lm/model.py index f19f9351..b3d096fc 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -43,7 +43,7 @@ from transformer_engine.common import recipe fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False @@ -462,7 +462,7 @@ def create_params(args): vocab_size=cfg["vocab_size"], post_embed_norm=cfg["post_embed_norm"], weight_tying=cfg["weight_tying"], - norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)), + norm_type=get_norm_class(cfg.get("model_norm", args.model_norm), args.use_fp8), attn_func=get_attn_func( args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha ), @@ -522,11 +522,11 @@ def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=["output"] def create_model(args): if "mamba" in args.model: model = Mamba(create_params(args)) - if using_te: + if args.use_fp8: model = nn_linear_to_te_linear(model) return model else: model = Transformer(create_params(args)) - if using_te: + if args.use_fp8: model = nn_linear_to_te_linear(model) return model diff --git a/open_lm/norms.py b/open_lm/norms.py index 6847ab5b..2084c6ba 100644 --- a/open_lm/norms.py +++ b/open_lm/norms.py @@ -15,7 +15,7 @@ from transformer_engine.common import recipe fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False @@ -33,6 +33,7 @@ def __init__( elementwise_bias: bool = True, device=None, dtype=None, + use_fp8: bool = False ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -44,6 +45,7 @@ def __init__( self.eps = eps self.elementwise_gain = elementwise_gain self.elementwise_bias = elementwise_bias + self.use_fp8 = use_fp8 if self.elementwise_gain: self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -67,7 +69,7 @@ def reset_parameters(self) -> None: self.bias.zero_() def forward(self, input: Tensor) -> Tensor: - if using_te: + if using_te and self.use_fp8: layer_norm_module = te.LayerNorm( self.normalized_shape, eps=self.eps, device="cuda", params_dtype=input.dtype ) @@ -98,7 +100,7 @@ def forward(self, x): downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None 
else self.bias with torch.autocast(enabled=False, device_type=module_device.type): - if using_te: + if using_te and self.use_fp8: layer_norm_module = te.LayerNorm( self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype ) @@ -163,18 +165,18 @@ def extra_repr(self) -> str: return "{normalized_shape}, eps={eps} ".format(**self.__dict__) -def get_norm_class(model_norm): +def get_norm_class(model_norm, use_fp8=False): if model_norm == "default_layer_norm": return torch.nn.LayerNorm elif model_norm == "lp_layer_norm": return LPLayerNorm elif model_norm == "gain_only_lp_layer_norm": - return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False) + return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False, use_fp8=use_fp8) elif model_norm == "gain_only_layer_norm": - return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False) + return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False, use_fp8=use_fp8) elif model_norm == "no_wb_layer_norm": - return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False) + return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False, use_fp8=use_fp8) elif model_norm == "rms_norm": return RmsNorm diff --git a/open_lm/params.py b/open_lm/params.py index eea63106..e9075dc6 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -477,7 +477,7 @@ def parse_args(args): ) parser.add_argument( "--precision", - choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], + choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32", "amp_fp8"], default="amp", help="Floating point precision.", ) @@ -761,6 +761,14 @@ def parse_args(args): default=0, help="Whether to log the average model training loss. if not 0, it will log the average loss over the specified number of steps.", ) + + parser.add_argument( + "--use-fp8", + action="store_true", + default=False, + help="If set, allow FP8 training for the model.", + ) + add_model_args(parser) config = maybe_load_config(parser, args) diff --git a/open_lm/train.py b/open_lm/train.py index 488068f4..89440a58 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -33,7 +33,7 @@ from transformer_engine.common import recipe fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False @@ -54,7 +54,7 @@ def backward(total_loss, scaler): def train_one_epoch( - model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None + model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None, all_gpus=None ): """Trains model for one epoch on the provided data. 
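# --- Illustrative aside (editor's sketch, not part of the patch): how the DelayedScaling
# --- recipe and te.fp8_autocast introduced by this series are typically combined around a
# --- forward pass. The names `model`, `inputs`, and `fp8_group` below are assumptions for
# --- illustration only; the real wiring lives in train_one_epoch and main.py.
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# HYBRID keeps E4M3 for forward tensors and E5M2 for backward gradients.
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
fp8_group = dist.new_group(backend="nccl")  # process group over which amax statistics are reduced

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group):
    out, _, _ = model(inputs)  # te.Linear modules inside the model run their GEMMs in FP8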
@@ -131,7 +131,7 @@ def train_one_epoch( optimizer.zero_grad() if args.accum_freq == 1: - if using_te: + if using_te and args.use_fp8: with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): inputs, targets = sample_chunk(texts, args) @@ -190,7 +190,7 @@ def train_one_epoch( break targets_ii = targets[ii * per_batch : (ii + 1) * per_batch] - if using_te: + if using_te and args.use_fp8: ## TODO with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): out, _, _ = model(inputs_ii) diff --git a/sagemaker_train/cfg_sample.yaml b/sagemaker_train/cfg_sample.yaml index 07158730..815d0910 100644 --- a/sagemaker_train/cfg_sample.yaml +++ b/sagemaker_train/cfg_sample.yaml @@ -5,7 +5,7 @@ data-key: "json" dataset-resampled: True # delete-previous-checkpoint: False # Total 25B * 40 = 1T tokens -epochs: 40 +epochs: 2 fsdp: True fsdp-limit-all-gathers: True # grad-checkpointing: False @@ -13,19 +13,19 @@ grad-clip-norm: 1 log-every-n-steps: 20 model: "open_lm_7b" name: "sample_7b" -precision: "amp_bfloat16" +precision: "amp_fp8" report-to: "wandb" seed: 124 train-data-mix-weights: [0.725, 0.275] train-data: ["TODO"] -train-num-samples: 25_000_000_000 +train-num-samples: 28_000_000_000 wandb-project-name: "lm1" workers: 4 logs: /opt/ml/checkpoints/ # Some important parameters, double checked with Mitchell: -batch-size: 16 -ffn-type: swiglu +batch-size: 128 +ffn-type: gemma_geglu # fsdp-amp: False fsdp-pure-bf16: True fsdp-backward-prefetch: True From 4563be2f0d2b2ffaadc0cc73d5f34f4ff7637d0c Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Fri, 12 Apr 2024 04:16:34 -0700 Subject: [PATCH 06/22] Linter changes --- open_lm/attention.py | 1 + open_lm/main.py | 2 +- open_lm/norms.py | 2 +- open_lm/params.py | 2 +- open_lm/train.py | 14 +++++++++++++- 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index a4960adc..ecc4b61a 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -16,6 +16,7 @@ except ImportError as ie: using_te = False + def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype): """ >>> get_rectangular_mask((1, 1), 2, 2, "cpu", torch.float32) diff --git a/open_lm/main.py b/open_lm/main.py index fe3558da..ac5409c9 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -463,7 +463,7 @@ def main(args): if args.grad_checkpointing: model.set_grad_checkpointing() - + all_gpus = dist.new_group(backend="nccl") if args.distributed: diff --git a/open_lm/norms.py b/open_lm/norms.py index 2084c6ba..598a444d 100644 --- a/open_lm/norms.py +++ b/open_lm/norms.py @@ -33,7 +33,7 @@ def __init__( elementwise_bias: bool = True, device=None, dtype=None, - use_fp8: bool = False + use_fp8: bool = False, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() diff --git a/open_lm/params.py b/open_lm/params.py index e9075dc6..33c9abe7 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -768,7 +768,7 @@ def parse_args(args): default=False, help="If set, allow FP8 training for the model.", ) - + add_model_args(parser) config = maybe_load_config(parser, args) diff --git a/open_lm/train.py b/open_lm/train.py index 89440a58..2230769a 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -54,7 +54,19 @@ def backward(total_loss, scaler): def train_one_epoch( - model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None, all_gpus=None + model, + data, + loss, + epoch, + step, + optimizer, + scaler, + scheduler, + 
total_steps, + args, + tb_writer=None, + averagers=None, + all_gpus=None, ): """Trains model for one epoch on the provided data. From 937927a293943117a90402d325b7bcd188b97fab Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Sat, 13 Apr 2024 01:12:23 -0700 Subject: [PATCH 07/22] Adding asserts for FP8 --- open_lm/main.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/open_lm/main.py b/open_lm/main.py index ac5409c9..ade85a20 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -100,6 +100,20 @@ def get_state_dict(name): return sd +def assert_fp8(model): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + assert_fp8(module) + if isinstance(module, torch.nn.Linear): + logging.warning(f"Module {name} is nn.Linear and not converted to TE FP8 equivalent of Linear.") + if isinstance(module, torch.nn.LayerNorm): + logging.warning(f"Module {name} is nn.LayerNorm and not converted to TE FP8 equivalent of LayerNorm.") + if isinstance(module, torch.nn.functional.scaled_dot_product_attention): + logging.warning( + f"Module {name} is torch.nn.functional.scaled_dot_product_attention and not converted to TE FP8 equivalent of DotProductAttention." + ) + + def load_model(args, model, different_seed=False): checkpoint = pt_load(args.resume, map_location="cpu") if "epoch" in checkpoint: @@ -447,6 +461,10 @@ def main(args): with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device): model = create_model(args) + if args.use_fp8: + logging.info("Using FP8 to run training.") + assert_fp8(model) + args.vocab_size = model.vocab_size args.seq_len = model.seq_len if args.train_num_samples is not None: From 1594b9f4c393356170a260454962c723e68f4360 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Sat, 13 Apr 2024 02:44:42 -0700 Subject: [PATCH 08/22] Asserts for FP8 --- open_lm/main.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index ade85a20..39025029 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -108,10 +108,6 @@ def assert_fp8(model): logging.warning(f"Module {name} is nn.Linear and not converted to TE FP8 equivalent of Linear.") if isinstance(module, torch.nn.LayerNorm): logging.warning(f"Module {name} is nn.LayerNorm and not converted to TE FP8 equivalent of LayerNorm.") - if isinstance(module, torch.nn.functional.scaled_dot_product_attention): - logging.warning( - f"Module {name} is torch.nn.functional.scaled_dot_product_attention and not converted to TE FP8 equivalent of DotProductAttention." - ) def load_model(args, model, different_seed=False): From 740f2b187fbdec70d10dcf2c63d71347f5426039 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Sat, 13 Apr 2024 03:50:16 -0700 Subject: [PATCH 09/22] Predefine all_gpus for TE --- open_lm/main.py | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index 39025029..d02f9c1d 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -478,7 +478,9 @@ def main(args): if args.grad_checkpointing: model.set_grad_checkpointing() - all_gpus = dist.new_group(backend="nccl") + all_gpus = None + if args.use_fp8: + all_gpus = dist.new_group(backend="nccl") if args.distributed: if args.fsdp: @@ -537,18 +539,31 @@ def main(args): # Initialize FSDP. Use the same seed across workers to ensure reset_parameters is the same across workers. 
random_seed(args.seed, rank=0) - model = FSDP( - model, - process_group=all_gpus, - auto_wrap_policy=transformer_auto_wrapper_policy, - device_id=device, - mixed_precision=mp_policy, - cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), - use_orig_params=args.fsdp_use_orig_params, - limit_all_gathers=args.fsdp_limit_all_gathers, - sync_module_states=True, - **fsdp_kwargs, - ) + + if args.use_fp8: + model = FSDP( + model, + process_group=all_gpus, + auto_wrap_policy=transformer_auto_wrapper_policy, + device_id=device, + mixed_precision=mp_policy, + cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), + use_orig_params=args.fsdp_use_orig_params, + limit_all_gathers=args.fsdp_limit_all_gathers, + sync_module_states=True, + **fsdp_kwargs, + ) + else: + model = FSDP( + model, + auto_wrap_policy=transformer_auto_wrapper_policy, + device_id=device, + mixed_precision=mp_policy, + cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), + use_orig_params=args.fsdp_use_orig_params, + limit_all_gathers=args.fsdp_limit_all_gathers, + **fsdp_kwargs, + ) print(f"After FSDP parameter num: {sum(p.numel() for p in model.parameters()):,} on rank {args.rank}") print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}") From 3713f61b4428b894020e85af6a23b3f004e2c471 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 17 Apr 2024 09:51:12 -0700 Subject: [PATCH 10/22] Remove if/else for fp8 checks --- open_lm/attention.py | 49 +++++++++++++++++++++++++++++ open_lm/model.py | 15 ++++++++- open_lm/norms.py | 75 ++++++++++++++++++++++++++------------------ 3 files changed, 107 insertions(+), 32 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index e9c9ba5c..786ee50a 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -149,6 +149,55 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None): ) +def torch_attn_te(queries, keys, values, is_causal, attention_mask=None): + _, _, num_q_heads, _ = queries.shape + _, _, num_k_heads, _ = keys.shape + scaleddotproductattn_module = te.DotProductAttention(num_attention_heads=num_q_heads, kv_channels=num_k_heads) + if is_causal and keys.shape[1] > queries.shape[1] > 1: + q_seq_len = queries.shape[1] + k_seq_len = keys.shape[1] + # Same as above, we would like to use: + # mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask().materialize((1, 1, q_seq_len, k_seq_len), queries.dtype, queries.device) + mask = get_rectangular_causal_mask((1, 1), q_seq_len, k_seq_len, queries.device, queries.dtype) + if attention_mask is not None: + apply_attention_mask_(mask, attention_mask, queries_dtype=queries.dtype) + return ( + scaleddotproductattn_module( + queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2), attention_mask=mask + ) + .transpose(1, 2) + .contiguous() + ) + else: + if attention_mask is None: + bias = None + # If we only have one query, assume we don't need to be in causal mode (can attend to all keys). + if queries.shape == 1: + is_causal = False + else: + if not is_causal: + raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.") + # Build causal mask that assumes queries are in the end of the sequence. 
+ batch, q_seq_len, heads, _ = queries.shape + k_seq_len = keys.shape[1] + bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype) + if attention_mask is not None: + apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype) + # We apply causal mask in attention instead of using is_causal=True. + is_causal = False + return ( + scaleddotproductattn_module( + queries.transpose(1, 2), + keys.transpose(1, 2), + values.transpose(1, 2), + attention_mask=bias, + attn_mask_type="causal" if is_causal else None, + ) + .transpose(1, 2) + .contiguous() + ) + + ATTN_ACTIVATIONS = { "relu": F.relu, "relu_squared": lambda x: torch.pow(F.relu(x), 2), diff --git a/open_lm/model.py b/open_lm/model.py index 76bdceb9..79d614ab 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -5,6 +5,7 @@ from pathlib import Path from dataclasses import dataclass from typing import Callable +import inspect import torch import torch.nn.functional as F @@ -513,7 +514,7 @@ def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=["output"] if len(list(module.children())) > 0: nn_linear_to_te_linear(module, include_modules, exclude_modules, copy_weights) - if isinstance(module, torch.nn.Linear) and name in include_modules and name not in exclude_modules: + if isinstance(module, torch.nn.Linear) and name not in exclude_modules: old_module = model._modules[name] model._modules[name] = te.Linear( module.in_features, module.out_features, module.bias is not None, device="cuda" @@ -522,6 +523,18 @@ def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=["output"] model._modules[name].weight_tensor.data.copy_(old_module.weight.data) if model._modules[name].bias is not None and old_module.bias is not None: model._modules[name].bias.data.copy_(old_module.bias) + elif isinstance(module, torch.nn.LayerNorm) and name not in exclude_modules: + logging.warning(f"[FP8] Module {name} is nn.LayerNorm and not converted to TE FP8 equivalent of LayerNorm.") + elif isinstance(module, torch.nn.Module) and name not in exclude_modules: + source_code = inspect.getsource(module.forward) + if "F.scaled_dot_product_attention" in source_code: + logging.warning( + f"[FP8] F.scaled_dot_product_attention -> te.DotProductAttention is not implemented yet for {name}." + ) + if "F.layer_norm" in source_code: + logging.warning( + f"[FP8] Module {name} is F.layer_norm and not converted to TE FP8 equivalent te.LayerNorm." 
+ ) return model diff --git a/open_lm/norms.py b/open_lm/norms.py index 598a444d..ec92e77d 100644 --- a/open_lm/norms.py +++ b/open_lm/norms.py @@ -33,7 +33,6 @@ def __init__( elementwise_bias: bool = True, device=None, dtype=None, - use_fp8: bool = False, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -45,7 +44,6 @@ def __init__( self.eps = eps self.elementwise_gain = elementwise_gain self.elementwise_bias = elementwise_bias - self.use_fp8 = use_fp8 if self.elementwise_gain: self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) @@ -69,16 +67,7 @@ def reset_parameters(self) -> None: self.bias.zero_() def forward(self, input: Tensor) -> Tensor: - if using_te and self.use_fp8: - layer_norm_module = te.LayerNorm( - self.normalized_shape, eps=self.eps, device="cuda", params_dtype=input.dtype - ) - output_tensor = layer_norm_module(input) - if self.weight is not None and self.bias is not None: - output_tensor = output_tensor * self.weight + self.bias - return output_tensor - else: - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) def extra_repr(self) -> str: return ( @@ -100,22 +89,13 @@ def forward(self, x): downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias with torch.autocast(enabled=False, device_type=module_device.type): - if using_te and self.use_fp8: - layer_norm_module = te.LayerNorm( - self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype - ) - output_tensor = layer_norm_module(downcast_x) - if downcast_weight is not None and downcast_bias is not None: - output_tensor = output_tensor * downcast_weight + downcast_bias - return output_tensor - else: - return F.layer_norm( - downcast_x, - self.normalized_shape, - downcast_weight, - downcast_bias, - self.eps, - ) + return F.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) def _cast_if_autocast_enabled(tensor): @@ -130,6 +110,31 @@ def _cast_if_autocast_enabled(tensor): return tensor +class LayerNormTE(LayerNorm): + def forward(self, x): + layer_norm_module = te.LayerNorm(self.normalized_shape, eps=self.eps, device="cuda", params_dtype=x.dtype) + output_tensor = layer_norm_module(x) + if self.weight is not None and self.bias is not None: + output_tensor = output_tensor * self.weight + self.bias + return output_tensor + + +class LPLayerNormTE(LayerNorm): + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + layer_norm_module = te.LayerNorm( + self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype + ) + output_tensor = layer_norm_module(downcast_x) + if downcast_weight is not None and downcast_bias is not None: + output_tensor = output_tensor * downcast_weight + downcast_bias + return output_tensor + + class RmsNorm(nn.Module): def __init__( self, @@ -169,14 +174,22 @@ def get_norm_class(model_norm, use_fp8=False): if model_norm == "default_layer_norm": return torch.nn.LayerNorm elif model_norm == 
"lp_layer_norm": + if use_fp8 and using_te: + return LPLayerNormTE return LPLayerNorm elif model_norm == "gain_only_lp_layer_norm": - return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False, use_fp8=use_fp8) + if use_fp8 and using_te: + return partial(LPLayerNormTE, elementwise_gain=True, elementwise_bias=False) + return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False) elif model_norm == "gain_only_layer_norm": - return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False, use_fp8=use_fp8) + if use_fp8 and using_te: + return partial(LayerNormTE, elementwise_gain=True, elementwise_bias=False) + return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False) elif model_norm == "no_wb_layer_norm": - return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False, use_fp8=use_fp8) + if use_fp8 and using_te: + return partial(LayerNormTE, elementwise_gain=False, elementwise_bias=False) + return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False) elif model_norm == "rms_norm": return RmsNorm From 14e82789bafdf52a536135a9c950bb6d1347c2c9 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 17 Apr 2024 10:20:58 -0700 Subject: [PATCH 11/22] Remove extra asserts --- open_lm/main.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index 0e548679..4f191f5d 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -100,16 +100,6 @@ def get_state_dict(name): return sd -def assert_fp8(model): - for name, module in model.named_children(): - if len(list(module.children())) > 0: - assert_fp8(module) - if isinstance(module, torch.nn.Linear): - logging.warning(f"Module {name} is nn.Linear and not converted to TE FP8 equivalent of Linear.") - if isinstance(module, torch.nn.LayerNorm): - logging.warning(f"Module {name} is nn.LayerNorm and not converted to TE FP8 equivalent of LayerNorm.") - - def load_model(args, model, different_seed=False): checkpoint = pt_load(args.resume, map_location="cpu") if "epoch" in checkpoint: @@ -463,7 +453,6 @@ def main(args): if args.use_fp8: logging.info("Using FP8 to run training.") - assert_fp8(model) args.vocab_size = model.vocab_size args.seq_len = model.seq_len From e572510d9ddc10c4bac23ef8b6462665d9edb1fb Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 17 Apr 2024 10:50:55 -0700 Subject: [PATCH 12/22] Removing unused deps --- open_lm/attention.py | 3 --- open_lm/model.py | 3 --- open_lm/norms.py | 3 --- 3 files changed, 9 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 786ee50a..1cd92a1b 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -8,10 +8,7 @@ using_te = False try: import transformer_engine.pytorch as te - from transformer_engine.common import recipe - fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False diff --git a/open_lm/model.py b/open_lm/model.py index 79d614ab..6c33af38 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -41,10 +41,7 @@ using_te = False try: import transformer_engine.pytorch as te - from transformer_engine.common import recipe - fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False diff --git a/open_lm/norms.py b/open_lm/norms.py index ec92e77d..38fb90b6 100644 --- a/open_lm/norms.py 
+++ b/open_lm/norms.py @@ -12,10 +12,7 @@ using_te = False try: import transformer_engine.pytorch as te - from transformer_engine.common import recipe - fp8_format = recipe.Format.HYBRID - fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") using_te = True except ImportError as ie: using_te = False From 8350cb92517412a0bec3f54dc602e566c75fc3b4 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 17 Apr 2024 10:53:53 -0700 Subject: [PATCH 13/22] Update routine for converting NN layers to TE equivalents --- open_lm/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 6c33af38..94294580 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -506,10 +506,10 @@ def forward(self, x): return out, None, None -def nn_linear_to_te_linear(model, include_modules=[], exclude_modules=["output"], copy_weights=True): +def torch_NN_to_TE(model, include_modules=[], exclude_modules=["output"], copy_weights=True): for name, module in model.named_children(): if len(list(module.children())) > 0: - nn_linear_to_te_linear(module, include_modules, exclude_modules, copy_weights) + torch_NN_to_TE(module, include_modules, exclude_modules, copy_weights) if isinstance(module, torch.nn.Linear) and name not in exclude_modules: old_module = model._modules[name] @@ -539,10 +539,10 @@ def create_model(args): if "mamba" in args.model: model = Mamba(create_params(args)) if args.use_fp8: - model = nn_linear_to_te_linear(model) + model = torch_NN_to_TE(model) return model else: model = Transformer(create_params(args)) if args.use_fp8: - model = nn_linear_to_te_linear(model) + model = torch_NN_to_TE(model) return model From 4e582a0944287f84f08f098342bd7a79b3bdf393 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 24 Apr 2024 02:43:45 -0700 Subject: [PATCH 14/22] Update FP8 flags and checks for layers --- open_lm/main.py | 41 +++++++++++------------------------ open_lm/model.py | 30 +++++++++++++++----------- open_lm/train.py | 56 +++++++++++++++++------------------------------- 3 files changed, 50 insertions(+), 77 deletions(-) diff --git a/open_lm/main.py b/open_lm/main.py index 4f191f5d..572944fc 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -451,7 +451,9 @@ def main(args): with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device): model = create_model(args) + all_gpus = None if args.use_fp8: + all_gpus = dist.new_group(backend="nccl") logging.info("Using FP8 to run training.") args.vocab_size = model.vocab_size @@ -471,10 +473,6 @@ def main(args): if args.grad_checkpointing: model.set_grad_checkpointing() - all_gpus = None - if args.use_fp8: - all_gpus = dist.new_group(backend="nccl") - if args.distributed: if args.fsdp: transformer_layer_cls = None @@ -533,30 +531,17 @@ def main(args): # Initialize FSDP. Use the same seed across workers to ensure reset_parameters is the same across workers. 
random_seed(args.seed, rank=0) - if args.use_fp8: - model = FSDP( - model, - process_group=all_gpus, - auto_wrap_policy=transformer_auto_wrapper_policy, - device_id=device, - mixed_precision=mp_policy, - cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), - use_orig_params=args.fsdp_use_orig_params, - limit_all_gathers=args.fsdp_limit_all_gathers, - sync_module_states=True, - **fsdp_kwargs, - ) - else: - model = FSDP( - model, - auto_wrap_policy=transformer_auto_wrapper_policy, - device_id=device, - mixed_precision=mp_policy, - cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), - use_orig_params=args.fsdp_use_orig_params, - limit_all_gathers=args.fsdp_limit_all_gathers, - **fsdp_kwargs, - ) + model = FSDP( + model, + process_group=all_gpus, + auto_wrap_policy=transformer_auto_wrapper_policy, + device_id=device, + mixed_precision=mp_policy, + cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), + use_orig_params=args.fsdp_use_orig_params, + limit_all_gathers=args.fsdp_limit_all_gathers, + **fsdp_kwargs, + ) print(f"After FSDP parameter num: {sum(p.numel() for p in model.parameters()):,} on rank {args.rank}") print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}") diff --git a/open_lm/model.py b/open_lm/model.py index 94294580..71f181ac 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -98,6 +98,8 @@ class Params: post_embed_norm: bool = False weight_tying: bool = False norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm + linear_type: nn.Linear = te.Linear if using_te else nn.Linear + linear_device: str = 'cuda' if using_te else None attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn apply_qk_norm: bool = False moe_loss_weight: float = 0.1 @@ -130,8 +132,8 @@ def __init__(self, layer_id, args: Params): super().__init__() self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads - self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False) - self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.in_proj = args.linear_type(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.linear_device) + self.out_proj = args.linear_type(args.n_heads * self.head_dim, args.dim, bias=False, device=args.linear_device) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func self.apply_qk_norm = args.apply_qk_norm @@ -206,13 +208,13 @@ class GemmaMLP(nn.Module): Modified from https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L182-L201 """ - def __init__(self, dim: int, hidden_dim: int, layer_id: int): + def __init__(self, dim: int, hidden_dim: int, layer_id: int, args: Params): super().__init__() self.dim = dim self.hidden_dim = hidden_dim - self.gate_proj = nn.Linear(dim, hidden_dim) - self.up_proj = nn.Linear(dim, hidden_dim) - self.down_proj = nn.Linear(hidden_dim, dim) + self.gate_proj = args.linear_type(dim, hidden_dim, device=args.linear_device) + self.up_proj = args.linear_type(dim, hidden_dim, device=args.linear_device) + self.down_proj = args.linear_type(hidden_dim, dim, device=args.linear_device) self._layer_id = layer_id def forward(self, x): @@ -236,10 +238,10 @@ def reset_parameters(self): # Same as pseudocode provided from xformers SwiGLU # https://github.com/facebookresearch/xformers class SwiGLUTorch(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim, bias=True): + def __init__(self, in_dim, hidden_dim, out_dim, args: Params, bias=True): 
super().__init__() - self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias) - self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias) + self.w12 = args.linear_type(in_dim, 2 * hidden_dim, bias=bias, device=args.linear_device) + self.w3 = args.linear_type(hidden_dim, out_dim, bias=bias, device=args.linear_device) def forward(self, x): gate, x = self.w12(x).chunk(2, dim=-1) @@ -263,17 +265,17 @@ def __init__(self, layer_id, args: Params): elif args.ffn_type == "swiglu_torch": # this follows llama / lit llama -- go to multiple of 256 self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256) - self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False) + self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, args, bias=False) elif args.ffn_type == "gelu": # Follows mosaic mpt7b, but without a bias. self.hidden_dim = args.dim * 4 - self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False) - self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False) + self._ff_w1 = args.linear_type(args.dim, self.hidden_dim, bias=False, device=args.linear_device) + self._ff_w2 = args.linear_type(self.hidden_dim, args.dim, bias=False, device=args.linear_device) self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) elif args.ffn_type == "gemma_geglu": # this follows llama / lit llama -- go to multiple of 256 self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256) - self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id) + self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id, args) elif args.ffn_type == "moe": moe_args = MoEArgs( hidden_size=args.dim, @@ -467,6 +469,8 @@ def create_params(args): post_embed_norm=cfg["post_embed_norm"], weight_tying=cfg["weight_tying"], norm_type=get_norm_class(cfg.get("model_norm", args.model_norm), args.use_fp8), + linear_type=te.Linear if (using_te and args.use_fp8) else nn.Linear, + linear_device='cuda' if (using_te and args.use_fp8) else None, attn_func=get_attn_func( args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha ), diff --git a/open_lm/train.py b/open_lm/train.py index 970aba2c..dd901ffe 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -142,41 +142,30 @@ def train_one_epoch( data_time_m.update(time.time() - end) optimizer.zero_grad() - if args.accum_freq == 1: - if using_te and args.use_fp8: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): - inputs, targets = sample_chunk(texts, args) - - out, _, _ = model(inputs) - - if args.log_logit_mean: - logit_m.update(torch.mean(out).item()) + if using_te and args.use_fp8: + autocast_func = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) + else: + autocast_func = autocast() - total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) - total_loss = total_lm_loss - if args.moe_freq > 0: - total_load_balancing_loss = batched_load_balancing_loss(moe_args) - clear_load_balancing_loss() - total_loss += total_load_balancing_loss - else: - with autocast(): - inputs, targets = sample_chunk(texts, args) + if args.accum_freq == 1: + with autocast_func: + inputs, targets = sample_chunk(texts, args) - out, _, _ = model(inputs) + out, _, _ = model(inputs) - if args.log_logit_mean: - logit_m.update(torch.mean(out).item()) + if args.log_logit_mean: + logit_m.update(torch.mean(out).item()) - total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) - total_loss = total_lm_loss - if 
args.moe_freq > 0: - total_load_balancing_loss = batched_load_balancing_loss(moe_args) - clear_load_balancing_loss() - total_loss += total_load_balancing_loss + total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1)) + total_loss = total_lm_loss + if args.moe_freq > 0: + total_load_balancing_loss = batched_load_balancing_loss(moe_args) + clear_load_balancing_loss() + total_loss += total_load_balancing_loss backward(total_loss, scaler) if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0: - with autocast(): + with autocast_func: for key, averager in averagers.avgs_dict.items(): with torch.no_grad(): out_avg, _, _ = averager.av_model(inputs) @@ -196,18 +185,13 @@ def train_one_epoch( if isinstance(model, FSDP) and ii != args.accum_freq - 1: maybe_no_sync = model.no_sync with maybe_no_sync(): - with autocast(): + with autocast_func: inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch] if inputs_ii.shape[0] == 0: break targets_ii = targets[ii * per_batch : (ii + 1) * per_batch] - if using_te and args.use_fp8: - ## TODO - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus): - out, _, _ = model(inputs_ii) - else: - out, _, _ = model(inputs_ii) + out, _, _ = model(inputs_ii) if args.log_logit_mean: logit_m.update(torch.mean(out).item()) @@ -224,7 +208,7 @@ def train_one_epoch( local_loss += local_load_balancing_loss backward(local_loss, scaler) - with autocast(): + with autocast_func: if ( averagers is not None and args.log_avg_model_training_loss From cdb0cf76ecd08be449d5792e5ea6c79f8cc478c9 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 24 Apr 2024 02:45:51 -0700 Subject: [PATCH 15/22] Linter checks --- open_lm/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 71f181ac..6df797ea 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -99,7 +99,7 @@ class Params: weight_tying: bool = False norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm linear_type: nn.Linear = te.Linear if using_te else nn.Linear - linear_device: str = 'cuda' if using_te else None + linear_device: str = "cuda" if using_te else None attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn apply_qk_norm: bool = False moe_loss_weight: float = 0.1 @@ -132,7 +132,9 @@ def __init__(self, layer_id, args: Params): super().__init__() self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads - self.in_proj = args.linear_type(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.linear_device) + self.in_proj = args.linear_type( + args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.linear_device + ) self.out_proj = args.linear_type(args.n_heads * self.head_dim, args.dim, bias=False, device=args.linear_device) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func @@ -470,7 +472,7 @@ def create_params(args): weight_tying=cfg["weight_tying"], norm_type=get_norm_class(cfg.get("model_norm", args.model_norm), args.use_fp8), linear_type=te.Linear if (using_te and args.use_fp8) else nn.Linear, - linear_device='cuda' if (using_te and args.use_fp8) else None, + linear_device="cuda" if (using_te and args.use_fp8) else None, attn_func=get_attn_func( args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha ), From afb46cb16a6329c57ee56cac1629d8ee122ff7d0 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 24 Apr 2024 03:17:38 -0700 Subject: 
[PATCH 16/22] Add checks for autocast function --- open_lm/model.py | 34 ---------------------------------- open_lm/train.py | 21 ++++++++++++--------- 2 files changed, 12 insertions(+), 43 deletions(-) diff --git a/open_lm/model.py b/open_lm/model.py index 6df797ea..63a160ca 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -5,7 +5,6 @@ from pathlib import Path from dataclasses import dataclass from typing import Callable -import inspect import torch import torch.nn.functional as F @@ -512,43 +511,10 @@ def forward(self, x): return out, None, None -def torch_NN_to_TE(model, include_modules=[], exclude_modules=["output"], copy_weights=True): - for name, module in model.named_children(): - if len(list(module.children())) > 0: - torch_NN_to_TE(module, include_modules, exclude_modules, copy_weights) - - if isinstance(module, torch.nn.Linear) and name not in exclude_modules: - old_module = model._modules[name] - model._modules[name] = te.Linear( - module.in_features, module.out_features, module.bias is not None, device="cuda" - ) - if copy_weights: - model._modules[name].weight_tensor.data.copy_(old_module.weight.data) - if model._modules[name].bias is not None and old_module.bias is not None: - model._modules[name].bias.data.copy_(old_module.bias) - elif isinstance(module, torch.nn.LayerNorm) and name not in exclude_modules: - logging.warning(f"[FP8] Module {name} is nn.LayerNorm and not converted to TE FP8 equivalent of LayerNorm.") - elif isinstance(module, torch.nn.Module) and name not in exclude_modules: - source_code = inspect.getsource(module.forward) - if "F.scaled_dot_product_attention" in source_code: - logging.warning( - f"[FP8] F.scaled_dot_product_attention -> te.DotProductAttention is not implemented yet for {name}." - ) - if "F.layer_norm" in source_code: - logging.warning( - f"[FP8] Module {name} is F.layer_norm and not converted to TE FP8 equivalent te.LayerNorm." 
- ) - return model - - def create_model(args): if "mamba" in args.model: model = Mamba(create_params(args)) - if args.use_fp8: - model = torch_NN_to_TE(model) return model else: model = Transformer(create_params(args)) - if args.use_fp8: - model = torch_NN_to_TE(model) return model diff --git a/open_lm/train.py b/open_lm/train.py index dd901ffe..f4dcb85b 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -142,13 +142,10 @@ def train_one_epoch( data_time_m.update(time.time() - end) optimizer.zero_grad() - if using_te and args.use_fp8: - autocast_func = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) - else: - autocast_func = autocast() - if args.accum_freq == 1: - with autocast_func: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + using_te and args.use_fp8 + ) else autocast(): inputs, targets = sample_chunk(texts, args) out, _, _ = model(inputs) @@ -165,7 +162,9 @@ def train_one_epoch( backward(total_loss, scaler) if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0: - with autocast_func: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + using_te and args.use_fp8 + ) else autocast(): for key, averager in averagers.avgs_dict.items(): with torch.no_grad(): out_avg, _, _ = averager.av_model(inputs) @@ -185,7 +184,9 @@ def train_one_epoch( if isinstance(model, FSDP) and ii != args.accum_freq - 1: maybe_no_sync = model.no_sync with maybe_no_sync(): - with autocast_func: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + using_te and args.use_fp8 + ) else autocast(): inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch] if inputs_ii.shape[0] == 0: break @@ -208,7 +209,9 @@ def train_one_epoch( local_loss += local_load_balancing_loss backward(local_loss, scaler) - with autocast_func: + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + using_te and args.use_fp8 + ) else autocast(): if ( averagers is not None and args.log_avg_model_training_loss From 00c9e5bf9e10bfa278dca0b1c7db97c03b3b549c Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 24 Apr 2024 03:44:29 -0700 Subject: [PATCH 17/22] Minor edit to model --- open_lm/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index 63a160ca..4d598fa1 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -97,7 +97,7 @@ class Params: post_embed_norm: bool = False weight_tying: bool = False norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm - linear_type: nn.Linear = te.Linear if using_te else nn.Linear + linear_type: nn.Module = te.Linear if using_te else nn.Linear linear_device: str = "cuda" if using_te else None attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn apply_qk_norm: bool = False From 40c7a6d0c088304cb075fdc08071fd9b75912171 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 24 Apr 2024 11:36:44 -0700 Subject: [PATCH 18/22] Adding default args as Params to SwiGLUTorch --- open_lm/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index 4d598fa1..8fd13649 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -239,7 +239,7 @@ def reset_parameters(self): # Same as pseudocode provided from xformers SwiGLU # https://github.com/facebookresearch/xformers class SwiGLUTorch(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim, args: Params, 
bias=True): + def __init__(self, in_dim, hidden_dim, out_dim, args: Params=Params, bias=True): super().__init__() self.w12 = args.linear_type(in_dim, 2 * hidden_dim, bias=bias, device=args.linear_device) self.w3 = args.linear_type(hidden_dim, out_dim, bias=bias, device=args.linear_device) From 8dbd1d87b0a031db39b189ba2c1ff7d6e32f40d7 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Wed, 24 Apr 2024 13:20:23 -0700 Subject: [PATCH 19/22] Linter fixes --- open_lm/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_lm/model.py b/open_lm/model.py index 8fd13649..d58f7887 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -239,7 +239,7 @@ def reset_parameters(self): # Same as pseudocode provided from xformers SwiGLU # https://github.com/facebookresearch/xformers class SwiGLUTorch(nn.Module): - def __init__(self, in_dim, hidden_dim, out_dim, args: Params=Params, bias=True): + def __init__(self, in_dim, hidden_dim, out_dim, args: Params = Params, bias=True): super().__init__() self.w12 = args.linear_type(in_dim, 2 * hidden_dim, bias=bias, device=args.linear_device) self.w3 = args.linear_type(hidden_dim, out_dim, bias=bias, device=args.linear_device) From ec917464d0f847010dd617cd6539df91501ade40 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Tue, 30 Apr 2024 12:33:36 -0700 Subject: [PATCH 20/22] Adding Torch Attention TE --- open_lm/attention.py | 9 +++------ open_lm/model.py | 6 +++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index 1cd92a1b..f22197ab 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -247,12 +247,7 @@ def custom_attn( return torch.einsum("bhqk,bkhd->bqhd", attn_weight, values) -def get_attn_func( - attn_name, - attn_activation=None, - attn_seq_scalar=None, - alpha=None, -): +def get_attn_func(attn_name, attn_activation=None, attn_seq_scalar=None, alpha=None, use_fp8=False): if attn_name == "auto": return xformers_attn if torch.cuda.is_available() else torch_attn elif attn_name == "xformers_attn": @@ -264,6 +259,8 @@ def get_attn_func( # call .contiguous() on the output tensor. 
[#188] return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous() elif attn_name == "torch_attn": + if using_te and use_fp8: + return torch_attn_te return torch_attn elif attn_name == "custom_attn": assert ( diff --git a/open_lm/model.py b/open_lm/model.py index d58f7887..e458977b 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -473,7 +473,11 @@ def create_params(args): linear_type=te.Linear if (using_te and args.use_fp8) else nn.Linear, linear_device="cuda" if (using_te and args.use_fp8) else None, attn_func=get_attn_func( - args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha + args.attn_name, + args.attn_activation, + args.attn_seq_scalar, + args.attn_seq_scalar_alpha, + use_fp8=args.use_fp8, ), apply_qk_norm=cfg.get("qk_norm", args.qk_norm), positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type), From 29000f366d68f192fef8861b98b0919e7cf9e712 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Mon, 6 May 2024 10:42:40 -0700 Subject: [PATCH 21/22] Fixing FP8+FSDP memory issues by removing FP8 from all activations until natively supports FSDP --- open_lm/attention.py | 10 ++++---- open_lm/distributed.py | 7 ++++-- open_lm/main.py | 16 ++++++------ open_lm/model.py | 56 ++++++++++++++++++++++++++++++------------ open_lm/norms.py | 7 +++--- open_lm/train.py | 10 ++++---- 6 files changed, 66 insertions(+), 40 deletions(-) diff --git a/open_lm/attention.py b/open_lm/attention.py index f22197ab..4c55c3bf 100644 --- a/open_lm/attention.py +++ b/open_lm/attention.py @@ -147,9 +147,9 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None): def torch_attn_te(queries, keys, values, is_causal, attention_mask=None): - _, _, num_q_heads, _ = queries.shape - _, _, num_k_heads, _ = keys.shape - scaleddotproductattn_module = te.DotProductAttention(num_attention_heads=num_q_heads, kv_channels=num_k_heads) + _, num_q_heads, _, _ = queries.shape + _, _, hidden_dim_k, _ = values.shape + scaleddotproductattn_module = te.DotProductAttention(num_attention_heads=num_q_heads, kv_channels=hidden_dim_k) if is_causal and keys.shape[1] > queries.shape[1] > 1: q_seq_len = queries.shape[1] k_seq_len = keys.shape[1] @@ -259,8 +259,8 @@ def get_attn_func(attn_name, attn_activation=None, attn_seq_scalar=None, alpha=N # call .contiguous() on the output tensor. [#188] return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous() elif attn_name == "torch_attn": - if using_te and use_fp8: - return torch_attn_te + # if using_te and use_fp8: + # return torch_attn_te return torch_attn elif attn_name == "custom_attn": assert ( diff --git a/open_lm/distributed.py b/open_lm/distributed.py index 8c07d663..95e22542 100644 --- a/open_lm/distributed.py +++ b/open_lm/distributed.py @@ -57,6 +57,7 @@ def init_distributed_device(args): args.world_size = 1 args.rank = 0 # global rank args.local_rank = 0 + args.world_group = None # For testing, allow forcing distributed mode to test distributed code path even on one gpu. 
if is_using_distributed() or args.force_distributed: if "SLURM_PROCID" in os.environ: @@ -74,7 +75,7 @@ def init_distributed_device(args): os.environ["LOCAL_RANK"] = str(args.local_rank) os.environ["RANK"] = str(args.rank) os.environ["WORLD_SIZE"] = str(args.world_size) - torch.distributed.init_process_group( + args.world_group = torch.distributed.init_process_group( backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, @@ -85,7 +86,9 @@ def init_distributed_device(args): # Note that this currently assumes that the world size is all gpus in a node. assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported." args.local_rank, _, _ = world_info_from_env() - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url) + args.world_group = torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url + ) args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() args.distributed = True diff --git a/open_lm/main.py b/open_lm/main.py index a3dad1ee..2cdb2b9d 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -443,18 +443,18 @@ def main(args): random_seed(args.seed, 0) + tensor_parallel_group = None + if args.use_fp8: + tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend="nccl") + logging.info("Using FP8 to run training.") + model = None if args.hf_model is not None: model = create_wrapped_hf_model(args) else: # Optional: Use meta device with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device): - model = create_model(args) - - all_gpus = None - if args.use_fp8: - all_gpus = dist.new_group(backend="nccl") - logging.info("Using FP8 to run training.") + model = create_model(args, tensor_parallel_group) args.vocab_size = model.vocab_size args.seq_len = model.seq_len @@ -533,7 +533,7 @@ def main(args): model = FSDP( model, - process_group=all_gpus, + process_group=args.world_group, auto_wrap_policy=transformer_auto_wrapper_policy, device_id=device, mixed_precision=mp_policy, @@ -815,7 +815,7 @@ def main(args): total_steps=total_steps, args=args, tb_writer=writer, - all_gpus=all_gpus, + data_parallel_group=args.world_group, ) if args.distributed: diff --git a/open_lm/model.py b/open_lm/model.py index c3f09b2e..ec21a070 100644 --- a/open_lm/model.py +++ b/open_lm/model.py @@ -37,10 +37,19 @@ # Adding flag if using TE FP8 using_te = False +LinearTE = nn.Linear try: import transformer_engine.pytorch as te using_te = True + + class LinearTE(te.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, inp: torch.Tensor, is_first_microbatch: bool = True): + return super().forward(inp, is_first_microbatch=True) + except ImportError as ie: using_te = False @@ -96,8 +105,8 @@ class Params: post_embed_norm: bool = False weight_tying: bool = False norm_type: nn.Module = te.LayerNorm if using_te else nn.LayerNorm - linear_type: nn.Module = te.Linear if using_te else nn.Linear - linear_device: str = "cuda" if using_te else None + linear_type: nn.Module = LinearTE if using_te else nn.Linear + te_device: str = "cuda" if using_te else None attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn apply_qk_norm: bool = False moe_loss_weight: float = 0.1 @@ -130,10 +139,8 @@ def __init__(self, layer_id, args: Params): super().__init__() self.n_heads = args.n_heads self.head_dim = args.dim // args.n_heads - self.in_proj = 
args.linear_type( - args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.linear_device - ) - self.out_proj = args.linear_type(args.n_heads * self.head_dim, args.dim, bias=False, device=args.linear_device) + self.in_proj = args.linear_type(args.dim, 3 * args.n_heads * self.head_dim, bias=False, device=args.te_device) + self.out_proj = args.linear_type(args.n_heads * self.head_dim, args.dim, bias=False, device=args.te_device) self.pos_embed = get_pos_embed(args) self.attn_fn = args.attn_func self.apply_qk_norm = args.apply_qk_norm @@ -143,6 +150,7 @@ def __init__(self, layer_id, args: Params): args.norm_type( args.n_heads * self.head_dim, eps=args.norm_eps, + device=args.te_device, ) if self.apply_qk_norm else nn.Identity() @@ -151,6 +159,7 @@ def __init__(self, layer_id, args: Params): args.norm_type( args.n_heads * self.head_dim, eps=args.norm_eps, + device=args.te_device, ) if self.apply_qk_norm else nn.Identity() @@ -212,9 +221,9 @@ def __init__(self, dim: int, hidden_dim: int, layer_id: int, args: Params): super().__init__() self.dim = dim self.hidden_dim = hidden_dim - self.gate_proj = args.linear_type(dim, hidden_dim, device=args.linear_device) - self.up_proj = args.linear_type(dim, hidden_dim, device=args.linear_device) - self.down_proj = args.linear_type(hidden_dim, dim, device=args.linear_device) + self.gate_proj = nn.Linear(dim, hidden_dim, device=args.te_device) + self.up_proj = nn.Linear(dim, hidden_dim, device=args.te_device) + self.down_proj = nn.Linear(hidden_dim, dim, device=args.te_device) self._layer_id = layer_id def forward(self, x): @@ -240,8 +249,8 @@ def reset_parameters(self): class SwiGLUTorch(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, args: Params = Params, bias=True): super().__init__() - self.w12 = args.linear_type(in_dim, 2 * hidden_dim, bias=bias, device=args.linear_device) - self.w3 = args.linear_type(hidden_dim, out_dim, bias=bias, device=args.linear_device) + self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias, device=args.te_device) + self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias, device=args.te_device) def forward(self, x): gate, x = self.w12(x).chunk(2, dim=-1) @@ -269,8 +278,8 @@ def __init__(self, layer_id, args: Params): elif args.ffn_type == "gelu": # Follows mosaic mpt7b, but without a bias. 
self.hidden_dim = args.dim * 4 - self._ff_w1 = args.linear_type(args.dim, self.hidden_dim, bias=False, device=args.linear_device) - self._ff_w2 = args.linear_type(self.hidden_dim, args.dim, bias=False, device=args.linear_device) + self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False, device=args.te_device) + self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False, device=args.te_device) self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2) elif args.ffn_type == "gemma_geglu": # this follows llama / lit llama -- go to multiple of 256 @@ -296,10 +305,12 @@ def __init__(self, layer_id, args: Params): self.attention_norm = args.norm_type( args.dim, eps=args.norm_eps, + device=args.te_device, ) self.ffn_norm = args.norm_type( args.dim, eps=args.norm_eps, + device=args.te_device, ) self.attention.seq_len = args.seq_len self.reset_parameters() @@ -469,8 +480,8 @@ def create_params(args): post_embed_norm=cfg["post_embed_norm"], weight_tying=cfg["weight_tying"], norm_type=get_norm_class(cfg.get("model_norm", args.model_norm), args.use_fp8), - linear_type=te.Linear if (using_te and args.use_fp8) else nn.Linear, - linear_device="cuda" if (using_te and args.use_fp8) else None, + linear_type=LinearTE if (using_te and args.use_fp8) else nn.Linear, + te_device="cuda" if (using_te and args.use_fp8) else None, attn_func=get_attn_func( args.attn_name, args.attn_activation, @@ -514,10 +525,23 @@ def forward(self, x): return out, None, None -def create_model(args): +def te_linear_ops(model, exclude_modules=["output"], tensor_parallel_group=None): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + te_linear_ops(module, exclude_modules, tensor_parallel_group) + if isinstance(module, te.Linear): + model._modules[name].set_tensor_parallel_group(tensor_parallel_group) + return model + + +def create_model(args, tensor_parallel_group=None): if "mamba" in args.model: model = Mamba(create_params(args)) + if tensor_parallel_group is not None and using_te: + model = te_linear_ops(model.to(torch.bfloat16).cuda(), tensor_parallel_group) return model else: model = Transformer(create_params(args)) + if tensor_parallel_group is not None and using_te: + model = te_linear_ops(model.to(torch.bfloat16).cuda(), tensor_parallel_group) return model diff --git a/open_lm/norms.py b/open_lm/norms.py index 38fb90b6..fc07993e 100644 --- a/open_lm/norms.py +++ b/open_lm/norms.py @@ -122,10 +122,9 @@ def forward(self, x): downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - with torch.autocast(enabled=False, device_type=module_device.type): - layer_norm_module = te.LayerNorm( - self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype - ) + layer_norm_module = te.LayerNorm( + self.normalized_shape, eps=self.eps, device="cuda", params_dtype=downcast_x.dtype + ) output_tensor = layer_norm_module(downcast_x) if downcast_weight is not None and downcast_bias is not None: output_tensor = output_tensor * downcast_weight + downcast_bias diff --git a/open_lm/train.py b/open_lm/train.py index 3ccdd645..b7ae6650 100644 --- a/open_lm/train.py +++ b/open_lm/train.py @@ -68,7 +68,7 @@ def train_one_epoch( args, tb_writer=None, averagers=None, - all_gpus=None, + data_parallel_group=None, ): """Trains model for one epoch on the provided data. 
@@ -145,7 +145,7 @@ def train_one_epoch( optimizer.zero_grad() if args.accum_freq == 1: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( using_te and args.use_fp8 ) else autocast(): inputs, targets = sample_chunk(texts, args) @@ -164,7 +164,7 @@ def train_one_epoch( backward(total_loss, scaler) if averagers is not None and args.log_avg_model_training_loss and i % args.log_avg_model_training_loss == 0: - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( using_te and args.use_fp8 ) else autocast(): for key, averager in averagers.avgs_dict.items(): @@ -186,7 +186,7 @@ def train_one_epoch( if isinstance(model, FSDP) and ii != args.accum_freq - 1: maybe_no_sync = model.no_sync with maybe_no_sync(): - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( using_te and args.use_fp8 ) else autocast(): inputs_ii = inputs[ii * per_batch : (ii + 1) * per_batch] @@ -211,7 +211,7 @@ def train_one_epoch( local_loss += local_load_balancing_loss backward(local_loss, scaler) - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=all_gpus) if ( + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group) if ( using_te and args.use_fp8 ) else autocast(): if ( From aca1b75de20851486733878f3aabb658a2f9b4f9 Mon Sep 17 00:00:00 2001 From: Romil Shah Date: Tue, 18 Jun 2024 20:54:12 -0700 Subject: [PATCH 22/22] Updating deps and config --- sagemaker_train/Dockerfile | 4 +++- sagemaker_train/Dockerfile_update | 2 ++ sagemaker_train/cfg_sample.yaml | 26 +++++++++++++++----------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sagemaker_train/Dockerfile b/sagemaker_train/Dockerfile index 36300449..f4d11c62 100644 --- a/sagemaker_train/Dockerfile +++ b/sagemaker_train/Dockerfile @@ -1,7 +1,7 @@ ARG AWS_REGION # SageMaker PyTorch image -FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.1.0-gpu-py310-cu121-ubuntu20.04-sagemaker +FROM 763104351884.dkr.ecr.${AWS_REGION}.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker # Run custom installation of libraries # RUN pip install xxx @@ -26,6 +26,8 @@ RUN rm /opt/ml/code/setup.py RUN pip install -r /opt/ml/code/requirements.txt RUN pip uninstall flash-attn -y RUN pip install flash-attn>=2.2 +RUN pip install s3fs>=2023.6.0 +RUN pip install --upgrade s3fs # # Prevent sagemaker from installing requirements again. # RUN rm /opt/ml/code/setup.py RUN rm /opt/ml/code/requirements.txt diff --git a/sagemaker_train/Dockerfile_update b/sagemaker_train/Dockerfile_update index 1282688c..c9fa936b 100644 --- a/sagemaker_train/Dockerfile_update +++ b/sagemaker_train/Dockerfile_update @@ -9,6 +9,8 @@ COPY . /opt/ml/code/ # RUN pip install -e /opt/ml/code/ # Prevent sagemaker from installing requirements again. 
+RUN pip install s3fs>=2023.6.0 +RUN pip install --upgrade s3fs RUN rm /opt/ml/code/setup.py RUN rm /opt/ml/code/requirements.txt diff --git a/sagemaker_train/cfg_sample.yaml b/sagemaker_train/cfg_sample.yaml index 815d0910..59b03277 100644 --- a/sagemaker_train/cfg_sample.yaml +++ b/sagemaker_train/cfg_sample.yaml @@ -1,11 +1,11 @@ accum-freq: 4 beta1: 0.9 beta2: 0.95 -data-key: "json" -dataset-resampled: True +data-key: "json.gz" +dataset-resampled: False # delete-previous-checkpoint: False # Total 25B * 40 = 1T tokens -epochs: 2 +epochs: 1 fsdp: True fsdp-limit-all-gathers: True # grad-checkpointing: False @@ -13,26 +13,30 @@ grad-clip-norm: 1 log-every-n-steps: 20 model: "open_lm_7b" name: "sample_7b" -precision: "amp_fp8" +precision: "amp_bfloat16" report-to: "wandb" seed: 124 -train-data-mix-weights: [0.725, 0.275] -train-data: ["TODO"] +# train-data-mix-weights: [0.725, 0.275] +dataset-manifest: ["TODO"] train-num-samples: 28_000_000_000 -wandb-project-name: "lm1" +wandb-project-name: "lm7" workers: 4 logs: /opt/ml/checkpoints/ # Some important parameters, double checked with Mitchell: -batch-size: 128 -ffn-type: gemma_geglu +global-batch-size: 32 +ffn-type: swiglu_torch # fsdp-amp: False fsdp-pure-bf16: True fsdp-backward-prefetch: True -lr: 3.e-4 +fsdp-use-orig-params: True +lr: 3.e-3 lr-cooldown-end: 3.e-5 model-norm: "gain_only_lp_layer_norm" qk-norm: True -warmup: 5000 +warmup: 2000 wd: 0.1 z-loss-coefficient: 1.e-4 +attn-name: torch_attn +torchcompile: True +use_fp8: False
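
Note on the autocast selection that train.py converges on in the later patches: each forward pass runs under te.fp8_autocast when Transformer Engine is importable and --use-fp8 is set, and under the regular autocast otherwise. Below is a minimal, self-contained sketch of that pattern; the recipe settings (HYBRID format, amax_history_len=16, amax_compute_algo="max") come from the patch, while the helper name fp8_or_bf16_autocast and the bf16 torch.autocast fallback are illustrative stand-ins for open_lm's own autocast helper.

# Sketch only -- not part of the patch. It mirrors the conditional context
# manager used in train.py: te.fp8_autocast under TE + FP8, torch autocast otherwise.
import torch

using_te = False
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Same recipe settings as the patch: hybrid E4M3 (forward) / E5M2 (backward),
    # with amax tracked over a short history and reduced with "max".
    fp8_recipe = recipe.DelayedScaling(
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=16,
        amax_compute_algo="max",
    )
    using_te = True
except ImportError:
    fp8_recipe = None


def fp8_or_bf16_autocast(use_fp8: bool, fp8_group=None):
    """Return the context manager a training step should run its forward under."""
    if using_te and use_fp8:
        # fp8_group is the process group over which FP8 amax/scale state is reduced.
        return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group)
    return torch.autocast(device_type="cuda", dtype=torch.bfloat16)

A training step then reads "with fp8_or_bf16_autocast(args.use_fp8, data_parallel_group): out, _, _ = model(inputs)", which is the shape of the inline conditional the patch repeats at each call site.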
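
Note on the Params-driven layer construction from PATCH 14/21: rather than branching on using_te inside every module, model.py stores linear_type and te_device on Params, so the attention projections build te.Linear on CUDA when FP8 is requested and plain nn.Linear otherwise. The sketch below isolates that mechanism under stated assumptions; TinyParams and TinyAttentionProj are illustrative stand-ins, not open_lm classes.

# Sketch only: layer selection via Params fields, following the patch's
# linear_type / te_device approach (te.Linear needs an explicit CUDA device).
from dataclasses import dataclass
from typing import Optional, Type

import torch.nn as nn

using_te = False
try:
    import transformer_engine.pytorch as te

    using_te = True
except ImportError:
    te = None


@dataclass
class TinyParams:
    dim: int = 512
    n_heads: int = 8
    use_fp8: bool = False

    @property
    def linear_type(self) -> Type[nn.Module]:
        return te.Linear if (using_te and self.use_fp8) else nn.Linear

    @property
    def te_device(self) -> Optional[str]:
        return "cuda" if (using_te and self.use_fp8) else None


class TinyAttentionProj(nn.Module):
    def __init__(self, args: TinyParams):
        super().__init__()
        head_dim = args.dim // args.n_heads
        # Both branches share one call signature; device=None is a no-op for nn.Linear.
        self.in_proj = args.linear_type(
            args.dim, 3 * args.n_heads * head_dim, bias=False, device=args.te_device
        )
        self.out_proj = args.linear_type(
            args.n_heads * head_dim, args.dim, bias=False, device=args.te_device
        )

Keeping the FP8 decision inside Params means the non-FP8 path stays exactly the previous nn.Linear behaviour, which appears to be why PATCH 16 can drop the post-hoc torch_NN_to_TE conversion routine.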
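
Note on the norm selection in norms.py: the per-forward use_fp8 branch is replaced by dedicated LayerNormTE / LPLayerNormTE subclasses, and get_norm_class(model_norm, use_fp8) decides which class to hand back. The sketch below shows the shape of that dispatch with a cut-down placeholder norm; the option names follow the patch, but the placeholder classes are illustrative and not the open_lm implementations, and the module is assumed to be moved to CUDA before its first forward (te.LayerNorm is GPU-only).

# Sketch only: get_norm_class-style dispatch with minimal placeholder norms.
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

using_te = False
try:
    import transformer_engine.pytorch as te

    using_te = True
except ImportError:
    te = None


class PlaceholderLayerNorm(nn.Module):
    """Minimal norm with optional gain/bias, mimicking the patch's constructor options."""

    def __init__(self, normalized_shape, eps=1e-5, elementwise_gain=True, elementwise_bias=True):
        super().__init__()
        self.normalized_shape = (normalized_shape,)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape)) if elementwise_gain else None
        self.bias = nn.Parameter(torch.zeros(normalized_shape)) if elementwise_bias else None

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)


class PlaceholderLayerNormTE(PlaceholderLayerNorm):
    """TE-backed variant: normalize with te.LayerNorm, then apply the local gain/bias."""

    def forward(self, x):
        # Assumes x (and this module) live on CUDA, matching the patch's device="cuda".
        ln = te.LayerNorm(self.normalized_shape[0], eps=self.eps, device="cuda", params_dtype=x.dtype)
        out = ln(x)
        if self.weight is not None and self.bias is not None:
            out = out * self.weight + self.bias
        return out


def get_norm_class(model_norm: str, use_fp8: bool = False):
    te_ok = use_fp8 and using_te
    if model_norm == "default_layer_norm":
        return nn.LayerNorm
    if model_norm == "gain_only_layer_norm":
        cls = PlaceholderLayerNormTE if te_ok else PlaceholderLayerNorm
        return partial(cls, elementwise_gain=True, elementwise_bias=False)
    if model_norm == "no_wb_layer_norm":
        cls = PlaceholderLayerNormTE if te_ok else PlaceholderLayerNorm
        return partial(cls, elementwise_gain=False, elementwise_bias=False)
    raise ValueError(f"Unsupported model_norm: {model_norm}")

The returned class (or partial) is then passed around as Params.norm_type, so every call site stays agnostic to whether the TE-backed or plain variant was picked.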