From 5b90f8bdd63ec51b46048993c3696189350936c5 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Wed, 11 Jun 2025 19:25:44 -0700 Subject: [PATCH 1/5] Define PEFT base class and LoRA transform Signed-off-by: Ananth Subramaniam --- src/megatron/hub/peft/base.py | 201 +++++++ src/megatron/hub/peft/lora.py | 199 +++++++ src/megatron/hub/peft/module_matcher.py | 102 ++-- src/megatron/hub/peft/walk_utils.py | 11 +- tests/unit_tests/peft/test_lora.py | 738 ++++++++++++++++++++++++ 5 files changed, 1181 insertions(+), 70 deletions(-) create mode 100644 src/megatron/hub/peft/base.py create mode 100644 src/megatron/hub/peft/lora.py create mode 100644 tests/unit_tests/peft/test_lora.py diff --git a/src/megatron/hub/peft/base.py b/src/megatron/hub/peft/base.py new file mode 100644 index 0000000000..826f666f53 --- /dev/null +++ b/src/megatron/hub/peft/base.py @@ -0,0 +1,201 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, TypeVar, Union + +import torch +import torch.nn as nn +from megatron.core.transformer.module import MegatronModule + +from megatron.hub.peft.walk_utils import walk + + +logger: logging.Logger = logging.getLogger(__name__) + +ModelType = TypeVar("ModelType", bound=Union[nn.Module, list[MegatronModule]]) + + +@dataclass +class PEFT(ABC): + """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods. + + This class defines the interface for PEFT methods, which are used to fine-tune + large language models efficiently by modifying only a small subset of the model's + parameters. + + Example: + class MyPEFT(PEFT): + def transform(self, module, name=None, prefix=None): + # Implement the transform logic + pass + + from megatron.hub.models import get_base_model + + peft = MyPEFT() + base_model = get_base_model(model_config) # Returns list[MegatronModule] + adapted_model = peft(base_model, training=True) + """ + + def __post_init__(self) -> None: + """Initialize runtime state after dataclass initialization.""" + self.params_to_save: set[str] = set() + + @abstractmethod + def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module: + """Transform a single module according to the PEFT method. + + This method is called for each module in the model during the PEFT application process. + It should be implemented by subclasses to define how individual modules are transformed + for the specific PEFT technique. + + Args: + module (nn.Module): The individual module to be transformed. + name (Optional[str]): The name of the module within the model structure. Defaults to None. + prefix (Optional[str]): A prefix to be added to the module name, typically used for + nested modules. Defaults to None. + + Returns: + nn.Module: The transformed module. This can be the original module with modifications, + a new module replacing the original, or the original module if no + transformation is needed for this specific module. + + Note: + This method is automatically called for each module in the model when the PEFT + instance is applied to the model using the __call__ method. + """ + raise NotImplementedError("The transform method should be implemented by subclasses.") + + def __call__(self, model: ModelType, training: bool = True) -> ModelType: + """Apply the PEFT method to the entire model. + + This method freezes the model parameters and walks through the model + structure, applying the transform method to each module. + + Args: + model: The model to be fine-tuned. Can be a single model or a list of model chunks + (for pipeline parallelism). + training (bool): Whether the model will be used for training. If False, + additional freezing may be applied. Defaults to True. + + Returns: + The same type as the input model, transformed with PEFT applied. + """ + self.freeze_model(model, training=training) + + if isinstance(model, list) and len(model) > 1: + for model_chunk in model: + walk(model_chunk, self.transform) + elif isinstance(model, torch.nn.parallel.DistributedDataParallel): + walk(model.module, self.transform) + else: + if isinstance(model, list): + model_to_walk = model[0] if len(model) == 1 else model + else: + model_to_walk = model + walk(model_to_walk, self.transform) + + if not training: + self.freeze_model(model, training=training) + + # Set model training mode appropriately + if isinstance(model, list): + for model_chunk in model: + model_chunk.train(mode=training) + else: + model.train(mode=training) + + return model + + def freeze_model(self, model: ModelType, training: bool = True) -> None: + """Apply a default freeze method to the model. + + This method freezes all the model parameters. This method can be overridden by subclasses to + implement custom freeze strategies (e.g. freeze only parts of the model) + + Args: + model: The model to be fine-tuned. + training (bool): Whether the model is being used for training. Affects training mode handling. + """ + + def freeze_parameters(module): + """Freeze all parameters in a module.""" + for param in module.parameters(recurse=False): + param.requires_grad = False + return module + + if isinstance(model, list): + for model_chunk in model: + walk(model_chunk, freeze_parameters) + elif isinstance(model, torch.nn.parallel.DistributedDataParallel): + walk(model.module, freeze_parameters) + else: + walk(model, freeze_parameters) + + if training: + if isinstance(model, list): + for model_chunk in model: + model_chunk.train(mode=True) + elif isinstance(model, torch.nn.parallel.DistributedDataParallel): + model.module.train(mode=True) + else: + model.train(mode=True) + + def set_params_to_save(self, model: ModelType) -> None: + """Set parameters to be saved for PEFT checkpointing. + + This method identifies which parameters should be saved during checkpointing + to reduce storage requirements (only adapter parameters, not the full model). + + Args: + model: The model after PEFT has been applied. + """ + # Handle both single models and lists of models + models_to_process = model if isinstance(model, list) else [model] + + self.params_to_save = set() + for model_chunk in models_to_process: + # Add all trainable parameters (adapters) + for name, param in model_chunk.named_parameters(): + if param.requires_grad: + self.params_to_save.add(name) + + # Add any relevant buffers (e.g., running stats from batch norm) + for module_name, module in model_chunk.named_modules(): + if hasattr(module, "track_running_stats"): + for buffer_name, buffer in module.named_buffers(): + if buffer is not None: + self.params_to_save.add(module_name + "." + buffer_name) + + def adapter_key_filter(self, key) -> bool: + """Filter function for adapter parameters during checkpointing. + + This method determines if a parameter should be included in checkpoints. + Used to save only adapter weights, not the full model. + + Args: + key (str or tuple): Parameter name/key to check. Can be a string for regular + checkpointing or a tuple for distributed checkpointing. + + Returns: + bool: True if the parameter should be saved. + """ + # Handle distributed checkpointing where keys can be tuples + if isinstance(key, tuple): + return key[1].requires_grad + + # Handle regular string keys + return key in self.params_to_save or ".adapter." in key or key.endswith(".adapters") diff --git a/src/megatron/hub/peft/lora.py b/src/megatron/hub/peft/lora.py new file mode 100644 index 0000000000..6c5c7006f0 --- /dev/null +++ b/src/megatron/hub/peft/lora.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field +from typing import List, Literal + +import torch +import torch.nn as nn + +from megatron.hub.peft.base import PEFT +from megatron.hub.peft.lora_layers import LinearAdapter, LoRALinear, TELinearAdapter, patch_linear_module +from megatron.hub.peft.module_matcher import ModuleMatcher +from megatron.hub.peft.utils import ParallelLinearAdapter, get_adapter_attributes_from_linear, is_expert_linear +from megatron.hub.utils.import_utils import safe_import + + +logger = logging.getLogger(__name__) + +try: + import transformer_engine.pytorch as te + + HAVE_TE = True +except ImportError: + te = None + HAVE_TE = False + +if torch.cuda.is_available(): + bitsandbytes, HAVE_BNB = safe_import("bitsandbytes") +else: + bitsandbytes = None + HAVE_BNB = False + + +@dataclass +class LoRA(PEFT, ModuleMatcher): + """ + Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning. + + LoRA uses a low-rank projection to adapt the weights of a pre-trained model to a new downstream task. + This class facilitates the application of LoRA to specific modules within the model architecture. + + Args: + target_modules (List[str], optional): A list of module names to apply LoRA to. + Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections + in self-attention. + - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention. + - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP. + - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP. + Target modules can also contain wildcards. For example, you can specify + target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv + on the first two layers. + exclude_modules (List[str], optional): A list of module names not to apply LoRa to. It will + match all nn.Linear & nn.Linear-adjacent modules whose name does not match any string in + exclude_modules. If used, will require target_modules to be empty list or None. + dim (int): Dimension of the low-rank projection space. Defaults to 32. + alpha (int): Weighting factor for the low-rank projection. Defaults to 32. + dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0. + dropout_position (Literal['pre', 'post'], optional): Position for applying dropout. + Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre'. + a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy. Defaults to False. + dropout_recompute (bool): Enables dropout recompute using Thunder JIT compilation. When True, + applies thunder.jit() to the dropout layer for memory-efficient training by recomputing + dropout activations during backward pass instead of storing them. + lora_dtype (torch.dtype): Parameter data type for LoRA weights. Default None (will use model's dtype). + + References: + ----------- + Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., & Chen, W. (2021). + LoRA: Low-Rank Adaptation of Large Language Models. arXiv preprint arXiv:2106.09685. + https://arxiv.org/abs/2106.09685 + + ) + """ + + target_modules: List[str] = field( + default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + ) + dim: int = 32 + alpha: int = 32 + dropout: float = 0.0 + dropout_position: Literal["pre", "post"] = "pre" + lora_A_init_method: str = "xavier" + lora_B_init_method: str = "zero" + a2a_experimental: bool = False + lora_dtype: torch.dtype = None + + def __post_init__(self): + """Initialize attributes from parent classes.""" + PEFT.__post_init__(self) + + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Applies LoRA to a specific module within the model architecture. + + Args: + m (nn.Module): The module to apply LoRA to. + name (str, optional): Name of the module (if applicable). Defaults to None. + prefix (str, optional): Prefix for the module name (if applicable). Defaults to None. + + Returns: + nn.Module: The modified module with LoRA applied, or the original module if not a target. + """ + + if (ans := self.match(m, name, prefix)) is not None: + (match, full_name) = ans + if isinstance(m, nn.Linear) or m.__class__ == te.Linear: + # Will use the `patch_linear_module` function if: + # - is FSDP v1 + # - is DTensor (has _local_tensor attribute) + # - has quant_state attribute + if hasattr(m.weight.data, "_local_tensor") or ( + getattr(m, "quant_state", None) is not None + and m.quant_state.__class__ == bitsandbytes.functional.QuantState + ): + lora_cls = patch_linear_module + elif HAVE_TE and m.__class__ == te.Linear: + lora_cls = TELinearAdapter + else: + lora_cls = LinearAdapter + + return lora_cls( + m, + dim=self.dim, + alpha=self.alpha, + dropout=self.dropout, + lora_A_init_method=self.lora_A_init_method, + lora_dtype=self.lora_dtype, + ) + + input_is_parallel, in_features, out_features, disable_sp_comm = get_adapter_attributes_from_linear(m) + logging.info(f"Adding lora to: {full_name}") + adapter = ParallelLinearAdapter( + in_features, + out_features, + self.dim, + base_linear_name=full_name, + activation="identity", + norm_type=None, + column_init_method=self.lora_A_init_method, + row_init_method=self.lora_B_init_method, + gather_output=False, + input_is_parallel=input_is_parallel, + dropout=self.dropout, + dropout_position=self.dropout_position, + model_parallel_config=getattr(m, "config", None), + alpha=self.alpha, + is_expert=is_expert_linear(full_name), + a2a_experimental=self.a2a_experimental, + disable_sequence_parallel_comm=disable_sp_comm, + ) + return LoRALinear(m, adapter) + return m + + +class LoRAMerge(PEFT): + """ + Implements the LoRA weight merge for parameter-efficient fine-tuning. + """ + + @torch.no_grad() + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Merges the LoRA adapter with the base model weights. + + Args: + m (nn.Module): The module to apply LoRA merge to. + name (str, optional): Name of the module to merge. Defaults to None. + prefix (str, optional): Prefix for the module name. Defaults to None. + + Returns: + nn.Module: The modified module with the LoRA adapter merged into the base model weights. + """ + + if not isinstance(m, LoRALinear): + return m + logging.info(f"merging {(prefix if prefix else '') + '.' + (name if name else '')}") + base_weight = m.to_wrap.weight + lora_weight = ( + m.adapter.alpha + / m.adapter.dim + * m.adapter.linear_out.weight.to(base_weight.device) + @ m.adapter.linear_in.weight.to(base_weight.device) + ) + merged_weight = base_weight + lora_weight + m.to_wrap.weight.data = merged_weight + return m diff --git a/src/megatron/hub/peft/module_matcher.py b/src/megatron/hub/peft/module_matcher.py index 318fc06a8b..e5997a42d1 100644 --- a/src/megatron/hub/peft/module_matcher.py +++ b/src/megatron/hub/peft/module_matcher.py @@ -14,15 +14,14 @@ from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Set -import torch.nn as nn from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear +from torch import nn from megatron.hub.peft.utils import wildcard_match from megatron.hub.utils.import_utils import safe_import_from - TEColumnParallelLinear, HAVE_TE_COL_LINEAR = safe_import_from( "megatron.core.extensions.transformer_engine", "TEColumnParallelLinear" ) @@ -38,47 +37,27 @@ @dataclass class ModuleMatcher: - """Module matcher for parameter-efficient fine-tuning (PEFT) applications. - - This class facilitates the identification and selection of modules within a model - architecture for applying PEFT techniques like LoRA (Low-Rank Adaptation). It provides - flexible matching patterns including wildcards, exact module names, and type-based - selection. + """ + Implements the LoRA (Low-Rank Adaptation) module for parameter-efficient fine-tuning. - The matcher supports three modes of operation: - 1. Canonical mapping: Uses predefined mappings to match modules - 2. Target modules: Matches against a list of target module names/patterns - 3. Exclude modules: Matches all linear layers except those explicitly excluded + LoRA uses a low-rank projection to adapt the weights of a pre-trained model to a new downstream task. + This class facilitates the application of LoRA to specific modules within the model architecture. Args: - target_modules (List[str], optional): A list of module names to apply PEFT to. + target_modules (List[str], optional): A list of module names to apply LoRA to. Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. - - 'linear_qkv': Apply to the fused linear layer used for query, key, and value projections + - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections in self-attention. - - 'linear_proj': Apply to the linear layer used for projecting the output of self-attention. - - 'linear_fc1': Apply to the first fully-connected layer in MLP. - - 'linear_fc2': Apply to the second fully-connected layer in MLP. + - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention. + - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP. + - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP. Target modules can also contain wildcards. For example, you can specify - target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add adapters to - only linear_qkv on the first two layers. - exclude_modules (List[str], optional): A list of module names to exclude from matching. - Only used when neither canonical_mapping nor target_modules are specified. - canonical_mapping (Dict[str, Set], optional): A mapping from pattern names to sets of - module types. Used for more complex matching scenarios. - - Example: - >>> # Match specific modules by name - >>> matcher = ModuleMatcher(target_modules=['linear_qkv', 'linear_proj']) - >>> - >>> # Match with wildcards - >>> matcher = ModuleMatcher(target_modules=['*.layers.*.linear_qkv']) - >>> - >>> # Exclude specific modules (matches all linear layers except excluded ones) - >>> matcher = ModuleMatcher(exclude_modules=['linear_fc1']) + target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv + on the first two layers. """ target_modules: List[str] = field( - default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + default_factory=lambda: ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'] ) exclude_modules: List[str] = field(default_factory=list) canonical_mapping: Dict[str, Set] = field(default_factory=lambda: defaultdict(set)) @@ -86,66 +65,55 @@ class ModuleMatcher: def match( self, m: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None ) -> Optional[Tuple[str, str]]: - """Determine whether a given module matches specified target patterns. + """ + Determines whether a given module matches specified target patterns. This function checks if the provided module `m` should be included based on predefined mapping rules (`canonical_mapping`, `target_modules`, and `exclude_modules`). It returns - the matching pattern and full name if a match is found; otherwise, it returns `None`. + the matching pattern if a match is found; otherwise, it returns `None`. Args: m (nn.Module): The module being checked. - name (str, optional): The module's name. Defaults to None. - prefix (str, optional): A prefix to be used in constructing `full_name`. Defaults to None. + name (str, optional): The module's name. + prefix (str, optional): A prefix to be used in constructing `full_name`. Returns: Optional[Tuple[str, str]]: A tuple containing (matching_pattern, full_name) if a match is found; otherwise, `None`. Matching Logic: - 1) If `canonical_mapping` is defined, it checks: - - Whether `name` exactly matches a pattern. - - Whether `full_name` matches any wildcard pattern in `canonical_mapping`. - 2) If `target_modules` is defined, it follows the same logic as `canonical_mapping`. - 3) If neither `canonical_mapping` nor `target_modules` are defined, it ensures: - - `name` is not in `exclude_modules`. - - `full_name` does not match any `exclude_modules` patterns. - - `m` is an instance of a supported linear layer type. + 1) If `canonical_mapping` is defined, it checks: + - Whether `name` exactly matches a pattern. + - Whether `full_name` matches any regex pattern in `canonical_mapping`. + 2) If `target_modules` is defined, it follows the same logic as `canonical_mapping`. + 3) If neither `canonical_mapping` nor `target_modules` are defined, it ensures: + - `name` is not in `exclude_modules`. + - `full_name` does not match any `target_modules` patterns. + - `m` is an instance of `nn.Linear`. Notes: - - `exclude_modules` should only be non-empty if neither `canonical_mapping` nor - `target_modules` are set. - - The function asserts that `exclude_modules` is empty when using `canonical_mapping` - or `target_modules`. - - Example: - >>> matcher = ModuleMatcher(target_modules=['*.linear_qkv']) - >>> result = matcher.match(module, 'linear_qkv', 'model.layers.0.self_attention') - >>> if result: - ... pattern, full_name = result - ... print(f"Matched {full_name} with pattern {pattern}") + - `exclude_modules` should only be non-empty if neither `canonical_mapping` nor `target_modules` are set. + - The function asserts that `exclude_modules` is empty when using `canonical_mapping` or `target_modules`. """ - full_name = f"{prefix}.{name}" if prefix else name + full_name = f"{prefix}.{name}" if prefix else name if len(self.canonical_mapping or []) > 0: """ Find the element in canonical_mapping which 1) matches the current `name` exactly, OR - 2) matches the current `full_name` with wildcard + 2) matches the current `full_name` with regex match is None if current module name doesn't match the specified targets. """ - assert len(self.exclude_modules) == 0, "exclude_modules should be empty when using canonical_mapping" + assert len(self.exclude_modules) == 0 for pattern in self.canonical_mapping: if name == pattern or wildcard_match(pattern, full_name): return (pattern, full_name) - elif len(self.target_modules or []) > 0: - assert len(self.exclude_modules) == 0, "exclude_modules should be empty when using target_modules" + assert len(self.exclude_modules) == 0 for pattern in self.target_modules: if name == pattern or wildcard_match(pattern, full_name): return (pattern, full_name) - else: - # Default mode: match all linear layers except excluded ones linear_types = [ColumnParallelLinear, RowParallelLinear, nn.Linear] if HAVE_TE_COL_LINEAR: linear_types.append(TEColumnParallelLinear) @@ -156,10 +124,10 @@ def match( linear_types = tuple(linear_types) if ( - name not in self.exclude_modules + not name in self.exclude_modules and not any(wildcard_match(pattern, full_name) for pattern in self.exclude_modules) and isinstance(m, linear_types) ): return (name, full_name) - return None + return None \ No newline at end of file diff --git a/src/megatron/hub/peft/walk_utils.py b/src/megatron/hub/peft/walk_utils.py index 9db01abc00..28812191b7 100644 --- a/src/megatron/hub/peft/walk_utils.py +++ b/src/megatron/hub/peft/walk_utils.py @@ -319,16 +319,21 @@ def _map_module_dict( module_dict = func(module_dict, **f_kwargs) mapped_modules = {} - for i, (name, module) in enumerate(module_dict.items()): - kwargs["i"] = i - kwargs["name"] = name + prefix = kwargs.get("name", "") if not kwargs.get("prefix", "") else f"{kwargs['prefix']}.{kwargs['name']}" + kwargs.pop("i", None) + kwargs.pop("name", None) + kwargs.pop("prefix", None) + for i, (name, module) in enumerate(module_dict.items()): mapped_modules[name] = map( module, func, recurse=recurse, leaf_only=leaf_only, transformed_modules=transformed_modules, + i=i, + name=name, + prefix=prefix, **kwargs, ) diff --git a/tests/unit_tests/peft/test_lora.py b/tests/unit_tests/peft/test_lora.py new file mode 100644 index 0000000000..fc9111d3b2 --- /dev/null +++ b/tests/unit_tests/peft/test_lora.py @@ -0,0 +1,738 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import os +from unittest.mock import patch + +import megatron.core.parallel_state as parallel_state +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from megatron.core.transformer.module import MegatronModule + +from megatron.hub.models import get_base_model +from megatron.hub.models.gpt import GPTConfig +from megatron.hub.peft.lora import LoRA, LoRAMerge +from megatron.hub.peft.lora_layers import LinearAdapter, LoRALinear + + +class SimpleModel(nn.Module): + """Simple test model with various linear layers.""" + + def __init__(self): + super().__init__() + self.embedding = nn.Embedding(1000, 512) + self.linear_qkv = nn.Linear(512, 1536) # Should be matched + self.linear_proj = nn.Linear(512, 512) # Should be matched + self.linear_fc1 = nn.Linear(512, 2048) # Should be matched + self.linear_fc2 = nn.Linear(2048, 512) # Should be matched + self.output_projection = nn.Linear(512, 1000) # Should NOT be matched (not in target_modules) + self.layernorm = nn.LayerNorm(512) + + +class NestedModel(nn.Module): + """Model with nested structure for testing pattern matching.""" + + def __init__(self): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.ModuleDict( + { + "attention": nn.ModuleDict( + { + "linear_qkv": nn.Linear(512, 1536), + "linear_proj": nn.Linear(512, 512), + } + ), + "mlp": nn.ModuleDict( + { + "linear_fc1": nn.Linear(512, 2048), + "linear_fc2": nn.Linear(2048, 512), + } + ), + } + ) + for _ in range(2) + ] + ) + + +class TestLoRA: + """Test suite for LoRA PEFT implementation.""" + + def test_lora_initialization(self): + """Test LoRA class initialization with default and custom parameters.""" + # Test default initialization + lora = LoRA() + assert lora.target_modules == ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] + assert lora.dim == 32 + assert lora.alpha == 32 + assert lora.dropout == 0.0 + assert lora.dropout_position == "pre" + assert lora.lora_A_init_method == "xavier" + assert lora.lora_B_init_method == "zero" + + # Test custom initialization + custom_lora = LoRA( + target_modules=["linear_qkv"], + dim=16, + alpha=16, + dropout=0.1, + dropout_position="post", + lora_A_init_method="uniform", + ) + assert custom_lora.target_modules == ["linear_qkv"] + assert custom_lora.dim == 16 + assert custom_lora.alpha == 16 + assert custom_lora.dropout == 0.1 + assert custom_lora.dropout_position == "post" + assert custom_lora.lora_A_init_method == "uniform" + + def test_lora_transform_simple_model(self): + """Test LoRA transformation on a simple model.""" + model = SimpleModel() + lora = LoRA(target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]) + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Check that target modules were transformed to LinearAdapter + assert isinstance(transformed_model.linear_qkv, LinearAdapter) + assert isinstance(transformed_model.linear_proj, LinearAdapter) + assert isinstance(transformed_model.linear_fc1, LinearAdapter) + assert isinstance(transformed_model.linear_fc2, LinearAdapter) + + # Check that non-target modules were not transformed + assert isinstance(transformed_model.output_projection, nn.Linear) + assert isinstance(transformed_model.embedding, nn.Embedding) + assert isinstance(transformed_model.layernorm, nn.LayerNorm) + + def test_lora_transform_with_exclude_modules(self): + """Test LoRA transformation with exclude_modules parameter.""" + model = SimpleModel() + # Use only exclude_modules (no target_modules) to test exclusion behavior + lora = LoRA( + target_modules=[], # Empty target_modules to use exclude mode + exclude_modules=["linear_fc2", "output_projection"], # Exclude specific linear modules + ) + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Check that excluded linear modules were not transformed + assert isinstance(transformed_model.linear_fc2, nn.Linear) + assert isinstance(transformed_model.output_projection, nn.Linear) + + # Check that non-excluded linear modules were transformed + # (In exclude mode, all linear layers except excluded ones should be transformed) + assert isinstance(transformed_model.linear_qkv, LinearAdapter) + assert isinstance(transformed_model.linear_proj, LinearAdapter) + assert isinstance(transformed_model.linear_fc1, LinearAdapter) + + # Non-linear modules should never be transformed regardless + assert isinstance(transformed_model.embedding, nn.Embedding) + assert isinstance(transformed_model.layernorm, nn.LayerNorm) + + def test_lora_transform_nested_model(self): + """Test LoRA transformation on nested model structures.""" + model = NestedModel() + lora = LoRA(target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]) + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Check that nested target modules were transformed + for layer in transformed_model.layers: + assert isinstance(layer["attention"]["linear_qkv"], LinearAdapter) + assert isinstance(layer["attention"]["linear_proj"], LinearAdapter) + assert isinstance(layer["mlp"]["linear_fc1"], LinearAdapter) + assert isinstance(layer["mlp"]["linear_fc2"], LinearAdapter) + + def test_lora_wildcard_matching(self): + """Test LoRA transformation with wildcard patterns.""" + model = NestedModel() + # Only apply LoRA to first layer's attention modules + lora = LoRA(target_modules=["layers.0.attention.*"]) + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Check first layer attention modules are transformed + assert isinstance(transformed_model.layers[0]["attention"]["linear_qkv"], LinearAdapter) + assert isinstance(transformed_model.layers[0]["attention"]["linear_proj"], LinearAdapter) + + # Check first layer MLP modules are NOT transformed + assert isinstance(transformed_model.layers[0]["mlp"]["linear_fc1"], nn.Linear) + assert isinstance(transformed_model.layers[0]["mlp"]["linear_fc2"], nn.Linear) + + # Check second layer modules are NOT transformed + assert isinstance(transformed_model.layers[1]["attention"]["linear_qkv"], nn.Linear) + assert isinstance(transformed_model.layers[1]["attention"]["linear_proj"], nn.Linear) + assert isinstance(transformed_model.layers[1]["mlp"]["linear_fc1"], nn.Linear) + assert isinstance(transformed_model.layers[1]["mlp"]["linear_fc2"], nn.Linear) + + def test_lora_adapter_properties(self): + """Test that LoRA adapters have correct properties.""" + model = SimpleModel() + lora = LoRA(dim=16, alpha=32, dropout=0.1) + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Check adapter properties + adapter = transformed_model.linear_qkv + assert hasattr(adapter, "dim") + assert hasattr(adapter, "scale") + assert hasattr(adapter, "lora_a") + assert hasattr(adapter, "lora_b") + assert hasattr(adapter, "dropout") + + assert adapter.dim == 16 + assert adapter.scale == 32 / 16 # alpha / dim + assert adapter.dropout.p == 0.1 + + def test_lora_parameter_freezing(self): + """Test that base model parameters are frozen and adapter parameters are trainable.""" + model = SimpleModel() + lora = LoRA() + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Check that original weights are frozen + linear_adapter = transformed_model.linear_qkv + assert not linear_adapter.weight.requires_grad + if linear_adapter.bias is not None: + assert not linear_adapter.bias.requires_grad + + # Check that LoRA parameters are trainable + assert linear_adapter.lora_a.weight.requires_grad + assert linear_adapter.lora_b.weight.requires_grad + + def test_lora_forward_pass(self): + """Test that LoRA adapted models can perform forward passes.""" + model = SimpleModel() + lora = LoRA(dim=8) + + # Apply LoRA + transformed_model = lora(model, training=True) + + # Test forward pass + batch_size, seq_len = 2, 10 + input_ids = torch.randint(0, 1000, (batch_size, seq_len)) + + with torch.no_grad(): + embeddings = transformed_model.embedding(input_ids) # [batch, seq, 512] + + # Test each adapted layer + qkv_out = transformed_model.linear_qkv(embeddings) # Should work + proj_out = transformed_model.linear_proj(embeddings) # Should work + fc1_out = transformed_model.linear_fc1(embeddings) # Should work + fc2_out = transformed_model.linear_fc2(fc1_out) # Should work + + assert qkv_out.shape == (batch_size, seq_len, 1536) + assert proj_out.shape == (batch_size, seq_len, 512) + assert fc1_out.shape == (batch_size, seq_len, 2048) + assert fc2_out.shape == (batch_size, seq_len, 512) + + def test_lora_training_vs_inference_mode(self): + """Test LoRA behavior in training vs inference mode.""" + model = SimpleModel() + lora = LoRA() + + # Test training mode + training_model = lora(model, training=True) + assert training_model.training + + # Test inference mode + inference_model = lora(model, training=False) + assert not inference_model.training + + @patch("megatron.hub.peft.lora.HAVE_TE", True) + @patch("megatron.hub.peft.lora.te") + def test_lora_te_linear_support(self, mock_te): + """Test LoRA support for Transformer Engine Linear layers.""" + + # Create the TE Linear type and an actual instance + class MockTELinear(nn.Module): + def __init__(self): + super().__init__() + + # Create a simple weight mock that doesn't have _local_tensor + class MockWeightData: + pass + + class MockWeight: + def __init__(self): + self.data = MockWeightData() + + self.weight = MockWeight() + self.quant_state = None + + # Set the mock_te.Linear to our MockTELinear class + mock_te.Linear = MockTELinear + + # Create an actual instance of our mock TE Linear + te_linear_instance = MockTELinear() + + # Create model with mock TE linear + model = nn.Module() + model.te_linear = te_linear_instance + + lora = LoRA(target_modules=["te_linear"]) + + # Mock the TELinearAdapter to avoid TE dependencies + with patch("megatron.hub.peft.lora.TELinearAdapter") as mock_te_adapter: + mock_te_adapter.return_value = te_linear_instance + + # Should create TELinearAdapter + _ = lora(model, training=True) + mock_te_adapter.assert_called_once() + + def test_lora_list_model_support(self): + """Test LoRA support for list of model chunks (pipeline parallelism).""" + # Create list of model chunks + model_chunks = [SimpleModel() for _ in range(3)] + lora = LoRA() + + # Apply LoRA to list of models + transformed_chunks = lora(model_chunks, training=True) + + # Should return list of same length + assert isinstance(transformed_chunks, list) + assert len(transformed_chunks) == 3 + + # Each chunk should have LoRA applied + for chunk in transformed_chunks: + assert isinstance(chunk.linear_qkv, LinearAdapter) + assert isinstance(chunk.linear_proj, LinearAdapter) + assert isinstance(chunk.linear_fc1, LinearAdapter) + assert isinstance(chunk.linear_fc2, LinearAdapter) + + +class TestLoRAMerge: + """Test suite for LoRA merge functionality.""" + + def test_lora_merge_initialization(self): + """Test LoRAMerge class initialization.""" + merge = LoRAMerge() + assert hasattr(merge, "transform") + + def test_lora_merge_transform(self): + """Test LoRA weight merging behavior with LinearAdapter instances.""" + # Create model and apply LoRA + model = SimpleModel() + lora = LoRA(dim=8, alpha=16) + adapted_model = lora(model, training=True) + + # Get original weights + original_weight = adapted_model.linear_qkv.weight.data.clone() + + # Create merge instance and apply + merge = LoRAMerge() + merged_model = merge(adapted_model, training=False) + + # Note: LoRAMerge only handles LoRALinear instances (Megatron modules), + # not LinearAdapter instances (regular nn.Linear modules). + # So for SimpleModel, the modules should remain as LinearAdapter unchanged. + assert isinstance(merged_model.linear_qkv, LinearAdapter) + + # Weights should be unchanged since merge doesn't apply to LinearAdapter + merged_weight = merged_model.linear_qkv.weight.data + assert torch.equal(original_weight, merged_weight) + + def test_lora_merge_with_lora_linear(self): + """Test LoRA weight merging with LoRALinear instances (the intended use case).""" + # Create a mock base module (representing a Megatron parallel module) + base_module = nn.Linear(64, 128) + original_weight = base_module.weight.data.clone() + + # Create a mock LoRA adapter that mimics ParallelLinearAdapter structure + class MockAdapter(nn.Module): + def __init__(self): + super().__init__() + self.alpha = 16 + self.dim = 8 + self.linear_in = nn.Linear(64, 8, bias=False) + self.linear_out = nn.Linear(8, 128, bias=False) + + # Initialize with small non-zero values to see merge effect + with torch.no_grad(): + self.linear_in.weight.data.fill_(0.1) + self.linear_out.weight.data.fill_(0.05) + + adapter = MockAdapter() + + # Create LoRALinear instance (what LoRA creates for Megatron modules) + lora_linear = LoRALinear(base_module, adapter) + + # Create merge instance and apply + merge = LoRAMerge() + merged_result = merge.transform(lora_linear) + + # Should return the LoRALinear wrapper (matches NeMo behavior) + assert merged_result is lora_linear + + # The underlying weight should be modified (merged) + merged_weight = lora_linear.to_wrap.weight.data + assert not torch.equal(original_weight, merged_weight) + + # The change should equal the LoRA adaptation + expected_lora_weight = (adapter.alpha / adapter.dim) * (adapter.linear_out.weight @ adapter.linear_in.weight) + expected_merged = original_weight + expected_lora_weight + assert torch.allclose(merged_weight, expected_merged, atol=1e-6) + + def test_lora_merge_non_lora_modules(self): + """Test that non-LoRA modules are unchanged during merge.""" + model = SimpleModel() + merge = LoRAMerge() + + # Apply merge to model without LoRA (should be no-op) + original_linear = model.linear_qkv + merged_model = merge(model, training=False) + + # Should be unchanged + assert merged_model.linear_qkv is original_linear + + +class TestLoRAIntegration: + """Integration tests for LoRA functionality.""" + + def test_lora_full_pipeline(self): + """Test complete LoRA application and merge pipeline.""" + # Create base model + model = SimpleModel() + original_weights = {} + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + original_weights[name] = module.weight.data.clone() + + # Apply LoRA + lora = LoRA(dim=4, alpha=8) + adapted_model = lora(model, training=True) + + # Verify LoRA was applied + assert isinstance(adapted_model.linear_qkv, LinearAdapter) + + # Perform training step (mock) + optimizer = torch.optim.Adam(adapted_model.parameters()) + + # Forward pass + input_ids = torch.randint(0, 1000, (2, 10)) + embeddings = adapted_model.embedding(input_ids) + output = adapted_model.linear_qkv(embeddings) + loss = output.sum() + + # Backward pass + loss.backward() + optimizer.step() + + # Merge LoRA weights + merge = LoRAMerge() + merged_model = merge(adapted_model, training=False) + + # Note: LoRAMerge only handles LoRALinear instances (Megatron modules), + # not LinearAdapter instances (regular nn.Linear modules). + # So for SimpleModel, merge should be a no-op. + assert isinstance(merged_model.linear_qkv, LinearAdapter) + + # The module should be unchanged since LoRAMerge doesn't affect LinearAdapter + assert merged_model.linear_qkv is adapted_model.linear_qkv + + def test_lora_parameter_efficiency(self): + """Test that LoRA significantly reduces trainable parameters.""" + model = SimpleModel() + + # Count original parameters + original_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + # Apply LoRA + lora = LoRA(dim=8) # Small rank for efficiency + adapted_model = lora(model, training=True) + + # Count trainable parameters after LoRA + lora_params = sum(p.numel() for p in adapted_model.parameters() if p.requires_grad) + + # LoRA should significantly reduce trainable parameters + assert lora_params < original_params + efficiency_ratio = lora_params / original_params + assert efficiency_ratio < 0.1 + + def test_lora_reproducibility(self): + """Test that LoRA application is deterministic.""" + torch.manual_seed(42) + model1 = SimpleModel() + lora1 = LoRA(dim=8, alpha=16) + adapted_model1 = lora1(model1, training=True) + + torch.manual_seed(42) + model2 = SimpleModel() + lora2 = LoRA(dim=8, alpha=16) + adapted_model2 = lora2(model2, training=True) + + # LoRA weights should be identical with same seed + lora_a_1 = adapted_model1.linear_qkv.lora_a.weight.data + lora_a_2 = adapted_model2.linear_qkv.lora_a.weight.data + assert torch.equal(lora_a_1, lora_a_2) + + lora_b_1 = adapted_model1.linear_qkv.lora_b.weight.data + lora_b_2 = adapted_model2.linear_qkv.lora_b.weight.data + assert torch.equal(lora_b_1, lora_b_2) + + +class TestLoRAMegatronIntegration: + """Integration tests for LoRA with real Megatron models.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown_parallel_state(self): + """Setup and teardown parallel state for Megatron tests.""" + + if not dist.is_initialized(): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + device_count = torch.cuda.device_count() + if device_count > 0: + torch.cuda.set_device(0) + + init_process_group_kwargs = { + "backend": "nccl" if device_count > 0 else "gloo", + "world_size": 1, + "rank": 0, + "timeout": datetime.timedelta(minutes=30), + } + + dist.init_process_group(**init_process_group_kwargs) + + assert dist.is_initialized(), "Distributed backend not initialized" + if not parallel_state.model_parallel_is_initialized(): + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=1, + ) + + assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.hub.training.initialize import _set_random_seed + + _set_random_seed( + seed_=1234, + data_parallel_random_init=False, + te_rng_tracker=True, + inference_rng_tracker=False, + ) + + yield + + try: + if parallel_state.model_parallel_is_initialized(): + parallel_state.destroy_model_parallel() + if dist.is_initialized(): + dist.destroy_process_group() + # Clean up environment variables + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE"]: + os.environ.pop(key, None) + except (NameError, AttributeError, RuntimeError): + pass + + def test_lora_with_gpt_model(self): + """Test LoRA application to a real GPT model from get_base_model.""" + + # Create a minimal GPT configuration + config = GPTConfig( + num_layers=2, + hidden_size=128, + num_attention_heads=2, + vocab_size=1000, + ffn_hidden_size=256, + ) + + base_model = get_base_model(config) + + # Verify we got a list of Megatron modules + assert isinstance(base_model, list) + assert len(base_model) > 0 + assert all(isinstance(chunk, MegatronModule) for chunk in base_model) + + # Ensure model is on CUDA if available + if torch.cuda.is_available(): + base_model = [chunk.cuda() for chunk in base_model] + + # Create LoRA instance targeting linear layers + lora = LoRA( + target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"], dim=8, alpha=16, dropout=0.0 + ) + + # Apply LoRA to the model + adapted_model = lora(base_model, training=True) + + # Verify we still have a list of the same length + assert isinstance(adapted_model, list) + assert len(adapted_model) == len(base_model) + + # Verify that LoRA was applied to target modules + found_lora_modules = [] + for chunk in adapted_model: + for name, module in chunk.named_modules(): + if isinstance(module, LoRALinear): + found_lora_modules.append(name) + + # Should have found some LoRA modules + assert len(found_lora_modules) > 0, "No LoRA modules found in adapted model" + + # Verify parameter states + total_params = 0 + trainable_params = 0 + for chunk in adapted_model: + for param in chunk.parameters(): + total_params += param.numel() + if param.requires_grad: + trainable_params += param.numel() + + # Should have significantly fewer trainable parameters than total + assert trainable_params < total_params + efficiency_ratio = trainable_params / total_params + assert efficiency_ratio < 0.3, f"LoRA should be parameter efficient, got ratio: {efficiency_ratio}" + + def test_lora_forward_pass_with_megatron_model(self): + """Test forward pass through LoRA-adapted Megatron model.""" + + # Create minimal config for fast testing + config = GPTConfig( + num_layers=1, + hidden_size=64, + num_attention_heads=2, + vocab_size=100, + ffn_hidden_size=128, + ) + + # Get and adapt model + base_model = get_base_model(config) + + # Ensure model is on CUDA if available + if torch.cuda.is_available(): + base_model = [chunk.cuda() for chunk in base_model] + + lora = LoRA(dim=4, alpha=8) + adapted_model = lora(base_model, training=True) + + # Test forward pass with proper Megatron input format + batch_size, seq_len = 2, 8 + + # Get model device (model is on CUDA, inputs need to match) + model_device = next(adapted_model[0].parameters()).device + + # Create input tensors in the format expected by Megatron models + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=model_device) + position_ids = torch.arange(seq_len, dtype=torch.long, device=model_device).unsqueeze(0).expand(batch_size, -1) + + # Create 4D causal attention mask [batch_size, 1, seq_len, seq_len] + # True values are masked out (don't attend), False values attend + attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=model_device)) < 0.5 + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) + + # Run forward pass using the standard codebase pattern + forward_args = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + + with torch.no_grad(): + for chunk in adapted_model: + output = chunk(**forward_args) + + # Verify output shape and that LoRA is active + if isinstance(output, tuple): + logits = output[0] + else: + logits = output + + expected_shape = (batch_size, seq_len, config.vocab_size) + assert logits.shape == expected_shape, f"Expected {expected_shape}, got {logits.shape}" + + # Count LoRA adaptations + lora_count = sum(1 for _, m in chunk.named_modules() if isinstance(m, LoRALinear)) + assert lora_count > 0, "Should have LoRA adaptations applied" + + def test_lora_merge_with_megatron_model(self): + """Test LoRA merge functionality with Megatron models.""" + + # Create minimal config + config = GPTConfig( + num_layers=1, + hidden_size=64, + num_attention_heads=2, + vocab_size=100, + ffn_hidden_size=128, + ) + + # Get base model and apply LoRA + base_model = get_base_model(config) + + # Move model to CUDA if available + if torch.cuda.is_available(): + base_model = [chunk.cuda() for chunk in base_model] + + lora = LoRA(dim=4, alpha=8) + adapted_model = lora(base_model, training=True) + + # Count LoRA modules before merge + lora_modules_before = 0 + original_weights = {} + for chunk in adapted_model: + for name, module in chunk.named_modules(): + if isinstance(module, LoRALinear): + lora_modules_before += 1 + # Store original weights to verify they change after merge + original_weights[name] = module.to_wrap.weight.data.clone() + + assert lora_modules_before > 0, "Should have some LoRA modules before merge" + + # Simulate training by making adapter weights non-zero + # (LoRA adapters start at zero, so merge would be no-op without this) + for chunk in adapted_model: + for name, module in chunk.named_modules(): + if isinstance(module, LoRALinear): + # Make adapter weights non-zero to simulate training + with torch.no_grad(): + module.adapter.linear_in.weight.data.fill_(0.1) + module.adapter.linear_out.weight.data.fill_(0.05) + + # Apply merge + merge = LoRAMerge() + merged_model = merge(adapted_model, training=False) + + # Count LoRA modules after merge + lora_modules_after = 0 + weights_changed = 0 + for chunk in merged_model: + for name, module in chunk.named_modules(): + if isinstance(module, LoRALinear): + lora_modules_after += 1 + # Check if weights were actually merged (changed) + if name in original_weights: + if not torch.equal(original_weights[name], module.to_wrap.weight.data): + weights_changed += 1 + + # LoRAMerge keeps the LoRALinear wrappers but merges the weights + assert lora_modules_after == lora_modules_before, "LoRAMerge keeps LoRALinear wrappers" + assert weights_changed > 0, "LoRAMerge should change the underlying weights" From b0f815db07a7188713681bc4f1aa12a61ed2f740 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Thu, 12 Jun 2025 09:19:48 -0700 Subject: [PATCH 2/5] run integration on gpu only Signed-off-by: Ananth Subramaniam --- tests/unit_tests/peft/test_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/peft/test_lora.py b/tests/unit_tests/peft/test_lora.py index fc9111d3b2..fa5ad5d925 100644 --- a/tests/unit_tests/peft/test_lora.py +++ b/tests/unit_tests/peft/test_lora.py @@ -494,6 +494,7 @@ def test_lora_reproducibility(self): assert torch.equal(lora_b_1, lora_b_2) +@pytest.mark.run_only_on("GPU") class TestLoRAMegatronIntegration: """Integration tests for LoRA with real Megatron models.""" From fd014ece9efaa6846e02a8313c4a0bbd34aab2b4 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Thu, 12 Jun 2025 17:23:00 -0700 Subject: [PATCH 3/5] rebase Signed-off-by: Ananth Subramaniam --- src/megatron/hub/peft/lora.py | 46 ++++++++++++------------- src/megatron/hub/peft/module_matcher.py | 9 ++--- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/megatron/hub/peft/lora.py b/src/megatron/hub/peft/lora.py index 6c5c7006f0..6a8e863ea3 100644 --- a/src/megatron/hub/peft/lora.py +++ b/src/megatron/hub/peft/lora.py @@ -14,7 +14,7 @@ import logging from dataclasses import dataclass, field -from typing import List, Literal +from typing import List, Literal, Optional import torch import torch.nn as nn @@ -101,7 +101,7 @@ def __post_init__(self): """Initialize attributes from parent classes.""" PEFT.__post_init__(self) - def transform(self, m: nn.Module, name=None, prefix=None): + def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module: """ Applies LoRA to a specific module within the model architecture. @@ -114,25 +114,25 @@ def transform(self, m: nn.Module, name=None, prefix=None): nn.Module: The modified module with LoRA applied, or the original module if not a target. """ - if (ans := self.match(m, name, prefix)) is not None: + if (ans := self.match(module, name, prefix)) is not None: (match, full_name) = ans - if isinstance(m, nn.Linear) or m.__class__ == te.Linear: + if isinstance(module, nn.Linear) or module.__class__ == te.Linear: # Will use the `patch_linear_module` function if: # - is FSDP v1 # - is DTensor (has _local_tensor attribute) # - has quant_state attribute - if hasattr(m.weight.data, "_local_tensor") or ( - getattr(m, "quant_state", None) is not None - and m.quant_state.__class__ == bitsandbytes.functional.QuantState + if hasattr(module.weight.data, "_local_tensor") or ( + getattr(module, "quant_state", None) is not None + and module.quant_state.__class__ == bitsandbytes.functional.QuantState ): lora_cls = patch_linear_module - elif HAVE_TE and m.__class__ == te.Linear: + elif HAVE_TE and module.__class__ == te.Linear: lora_cls = TELinearAdapter else: lora_cls = LinearAdapter return lora_cls( - m, + module, dim=self.dim, alpha=self.alpha, dropout=self.dropout, @@ -140,7 +140,7 @@ def transform(self, m: nn.Module, name=None, prefix=None): lora_dtype=self.lora_dtype, ) - input_is_parallel, in_features, out_features, disable_sp_comm = get_adapter_attributes_from_linear(m) + input_is_parallel, in_features, out_features, disable_sp_comm = get_adapter_attributes_from_linear(module) logging.info(f"Adding lora to: {full_name}") adapter = ParallelLinearAdapter( in_features, @@ -155,14 +155,14 @@ def transform(self, m: nn.Module, name=None, prefix=None): input_is_parallel=input_is_parallel, dropout=self.dropout, dropout_position=self.dropout_position, - model_parallel_config=getattr(m, "config", None), + model_parallel_config=getattr(module, "config", None), alpha=self.alpha, is_expert=is_expert_linear(full_name), a2a_experimental=self.a2a_experimental, disable_sequence_parallel_comm=disable_sp_comm, ) - return LoRALinear(m, adapter) - return m + return LoRALinear(module, adapter) + return module class LoRAMerge(PEFT): @@ -171,7 +171,7 @@ class LoRAMerge(PEFT): """ @torch.no_grad() - def transform(self, m: nn.Module, name=None, prefix=None): + def transform(self, module: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None) -> nn.Module: """ Merges the LoRA adapter with the base model weights. @@ -184,16 +184,16 @@ def transform(self, m: nn.Module, name=None, prefix=None): nn.Module: The modified module with the LoRA adapter merged into the base model weights. """ - if not isinstance(m, LoRALinear): - return m + if not isinstance(module, LoRALinear): + return module logging.info(f"merging {(prefix if prefix else '') + '.' + (name if name else '')}") - base_weight = m.to_wrap.weight + base_weight = module.to_wrap.weight lora_weight = ( - m.adapter.alpha - / m.adapter.dim - * m.adapter.linear_out.weight.to(base_weight.device) - @ m.adapter.linear_in.weight.to(base_weight.device) + module.adapter.alpha + / module.adapter.dim + * module.adapter.linear_out.weight.to(base_weight.device) + @ module.adapter.linear_in.weight.to(base_weight.device) ) merged_weight = base_weight + lora_weight - m.to_wrap.weight.data = merged_weight - return m + module.to_wrap.weight.data = merged_weight + return module diff --git a/src/megatron/hub/peft/module_matcher.py b/src/megatron/hub/peft/module_matcher.py index e5997a42d1..2bbb82796b 100644 --- a/src/megatron/hub/peft/module_matcher.py +++ b/src/megatron/hub/peft/module_matcher.py @@ -14,7 +14,7 @@ from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear from torch import nn @@ -22,6 +22,7 @@ from megatron.hub.peft.utils import wildcard_match from megatron.hub.utils.import_utils import safe_import_from + TEColumnParallelLinear, HAVE_TE_COL_LINEAR = safe_import_from( "megatron.core.extensions.transformer_engine", "TEColumnParallelLinear" ) @@ -57,14 +58,14 @@ class ModuleMatcher: """ target_modules: List[str] = field( - default_factory=lambda: ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'] + default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"] ) exclude_modules: List[str] = field(default_factory=list) canonical_mapping: Dict[str, Set] = field(default_factory=lambda: defaultdict(set)) def match( self, m: nn.Module, name: Optional[str] = None, prefix: Optional[str] = None - ) -> Optional[Tuple[str, str]]: + ) -> Optional[tuple[str, str]]: """ Determines whether a given module matches specified target patterns. @@ -130,4 +131,4 @@ def match( ): return (name, full_name) - return None \ No newline at end of file + return None From 4256a0b0752eca5928b0387399933465dbde5e47 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Thu, 12 Jun 2025 17:25:41 -0700 Subject: [PATCH 4/5] rebase Signed-off-by: Ananth Subramaniam --- src/megatron/hub/peft/module_matcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/megatron/hub/peft/module_matcher.py b/src/megatron/hub/peft/module_matcher.py index 2bbb82796b..8e527a7886 100644 --- a/src/megatron/hub/peft/module_matcher.py +++ b/src/megatron/hub/peft/module_matcher.py @@ -102,15 +102,15 @@ def match( """ Find the element in canonical_mapping which 1) matches the current `name` exactly, OR - 2) matches the current `full_name` with regex + 2) matches the current `full_name` with wildcard match is None if current module name doesn't match the specified targets. """ - assert len(self.exclude_modules) == 0 + assert len(self.exclude_modules) == 0, "exclude_modules should be empty when using canonical_mapping" for pattern in self.canonical_mapping: if name == pattern or wildcard_match(pattern, full_name): return (pattern, full_name) elif len(self.target_modules or []) > 0: - assert len(self.exclude_modules) == 0 + assert len(self.exclude_modules) == 0, "exclude_modules should be empty when using target_modules" for pattern in self.target_modules: if name == pattern or wildcard_match(pattern, full_name): return (pattern, full_name) @@ -125,7 +125,7 @@ def match( linear_types = tuple(linear_types) if ( - not name in self.exclude_modules + name not in self.exclude_modules and not any(wildcard_match(pattern, full_name) for pattern in self.exclude_modules) and isinstance(m, linear_types) ): From c9c5b74663b52dbd7d6e2b5598172135b898cad2 Mon Sep 17 00:00:00 2001 From: Ananth Subramaniam Date: Fri, 13 Jun 2025 09:16:46 -0700 Subject: [PATCH 5/5] update import Signed-off-by: Ananth Subramaniam --- tests/unit_tests/peft/test_lora_layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/peft/test_lora_layers.py b/tests/unit_tests/peft/test_lora_layers.py index 2fef1f11c6..ba986da9b8 100644 --- a/tests/unit_tests/peft/test_lora_layers.py +++ b/tests/unit_tests/peft/test_lora_layers.py @@ -31,7 +31,8 @@ # Test if Transformer Engine is available try: import transformer_engine.pytorch as te - from nemo_lm.peft.lora import TELinearAdapter + + from megatron.hub.peft.lora import TELinearAdapter HAVE_TE = True except ImportError: