diff --git a/gpt_conf.py b/gpt_conf.py
index 39ee976422..b5abda10f9 100644
--- a/gpt_conf.py
+++ b/gpt_conf.py
@@ -191,6 +191,7 @@ class GPTConfig:
     activation_end: str = "relu"
     activation_transition_start_iter: int = 0
     activation_transition_end_iter: int = None
+    use_offchip_peri_ln: bool = False
 
     # MLP Options
     use_parallel_mlp: bool = False
diff --git a/train_args.py b/train_args.py
index d12d5f24da..46121920de 100644
--- a/train_args.py
+++ b/train_args.py
@@ -873,6 +873,7 @@ def parse_args():
     model_group.add_argument("--activation_end", type=str, default="relu", choices=activation_variations)
     model_group.add_argument("--activation_transition_start_iter", type=int, default=0)
     model_group.add_argument("--activation_transition_end_iter", type=int, default=None, help="If None, defaults to max_iters from training config.")
+    model_group.add_argument('--use_offchip_peri_ln', default=False, action=argparse.BooleanOptionalAction, help="apply after combining the residual off chip")
 
     # Quantization
     model_group.add_argument("--full_quant_iteration", type=int, default=None,
diff --git a/variations/block_variations.py b/variations/block_variations.py
index fd485b92f7..90166f93dc 100644
--- a/variations/block_variations.py
+++ b/variations/block_variations.py
@@ -232,6 +232,11 @@ def edgellm_asic_forward(block, x: torch.Tensor, iter_num: int) -> torch.Tensor:
     # Therefore subtract initial before merging
     # x = (chip_output - x_quantized_residual_initial) + x
     adj_chip_output = chip_output - x_quantized_residual_initial
+
+    # Off-Chip Peri-LN
+    if getattr(block, "use_offchip_peri_ln", False):
+        x = block.offchip_peri_ln(x)
+
     x = block._combine_resid("mlp", x, adj_chip_output)
 
     if block.quantization_dict["quantize_asic_offchip_residual"]:
@@ -313,6 +318,10 @@ def _setup_norms_sequential(self, config, norm_cls) -> None:
     if getattr(self, "use_post_ln_mlp", False):
         self.post_ln_mlp = norm_cls(config)
 
+    # Off-chip Peri-LN (EdgeLLM ASIC only)
+    if getattr(self, "use_offchip_peri_ln", False):
+        self.offchip_peri_ln = norm_cls(config)
+
 
 normalization_setup_variations = {
     "parallel_mlp": _setup_norms_parallel,
@@ -377,6 +386,7 @@ def __init__(self, config, mlp=None, attn=None):
         self.use_pre_ln = getattr(config, "use_pre_ln", False)
         self.use_post_ln = getattr(config, "use_post_ln", False)
         self.use_peri_ln = getattr(config, "use_peri_ln", False)
+        self.use_offchip_peri_ln = getattr(config, "use_offchip_peri_ln", False)
 
         # Forward variation choice
         self.use_parallel_mlp = getattr(config, "use_parallel_mlp", False)
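
A minimal usage sketch, not part of the patch itself: it assumes all GPTConfig fields carry defaults so the dataclass can be constructed with only the new flag. On the command line the same option is exposed as --use_offchip_peri_ln / --no-use_offchip_peri_ln via argparse.BooleanOptionalAction.

    from gpt_conf import GPTConfig

    # Enable the off-chip Peri-LN; only the EdgeLLM ASIC forward path
    # (edgellm_asic_forward) and the sequential norm setup consult this flag.
    config = GPTConfig(use_offchip_peri_ln=True)
    print(config.use_offchip_peri_ln)  # True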