From 955180be1bd223c8babd6c1370f441ef27ad099b Mon Sep 17 00:00:00 2001
From: Alessandro Sordoni
Date: Mon, 11 Nov 2024 20:16:23 -0800
Subject: [PATCH 1/2] update requirements

---
 requirements.txt           |  3 ++-
 tests/test_expert_model.py | 48 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index c9b4fa7fc..4f462e05e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ datasets>=2.20.0
 pytorch-lightning>=2.3.3
 accelerate
 deepspeed
-huggingface_hub
+huggingface_hub>=0.26.2
 click
 wandb
 rouge
@@ -24,5 +24,6 @@ seaborn
 azure-storage-blob
 azure-identity
 einops
+triton
 nltk
 # spops @ git+https://github.com/IST-DASLab/spops.git@main
diff --git a/tests/test_expert_model.py b/tests/test_expert_model.py
index cac3c2f3a..d55f4937a 100644
--- a/tests/test_expert_model.py
+++ b/tests/test_expert_model.py
@@ -172,6 +172,54 @@ def test_from_pretrained_multi_selector(tmp_path):
     assert model.selector_config.get("lora").__class__ == UniformSelectorConfig
 
 
+def test_from_pretrained_with_arrow_save_and_reload(tmp_path):
+    # create a dummy library
+    model = MultiExpertModel(MultiExpertModelConfig("EleutherAI/gpt-neo-125m"))
+    model.add_empty_expert(
+        "a", LoRAConfig(modify_layers=".*out_proj", lora_init_b_random=True)
+    )
+    model.add_empty_expert(
+        "b", LoRAConfig(modify_layers=".*out_proj", lora_init_b_random=True)
+    )
+    library = model.save_to_library(f"local://{tmp_path}")
+
+    # store arrow experts
+    protos = ArrowTransform(ArrowConfig()).transform(library, persist=True)
+
+    # from pretrained library
+    selector_config = ArrowSelectorConfig(top_k=4)
+    model = MultiExpertModel.from_pretrained_library(
+        library, selector_config=selector_config
+    )
+    assert len(model.experts_names) == 2
+    # the order might be different due to multi-threading in adding experts in parallel
+    assert "a" in model.experts_names
+    assert "b" in model.experts_names
+
+    selector = model.selectors["lora"][0]
+    assert selector.config == selector_config
+    assert isinstance(selector, ArrowSelector)
+    model.save_pretrained(tmp_path)
+
+    # reload from the checkpoint
+    model = MultiExpertModel.from_pretrained(tmp_path)
+    selector = model.selectors["lora"][0]
+    assert selector.prototypes.shape[0] == 2
+    name1 = selector.expert_names[0]
+    name2 = selector.expert_names[1]
+    ln = selector.layer_name.replace(".selector", "")
+
+    assert np.allclose(
+        selector.prototypes[0].sum().item(),
+        protos[name1][ln].sum().item(),
+    )
+
+    assert np.allclose(
+        selector.prototypes[1].sum().item(),
+        protos[name2][ln].sum().item(),
+    )
+
+
 def test_from_pretrained_with_arrow(tmp_path):
     # create a dummy library
     model = MultiExpertModel(MultiExpertModelConfig("EleutherAI/gpt-neo-125m"))

From 7d3cffb6ffcb68a151e8303b061bef162a83bf23 Mon Sep 17 00:00:00 2001
From: Alessandro Sordoni
Date: Mon, 11 Nov 2024 20:21:02 -0800
Subject: [PATCH 2/2] clean unused triton imports

---
 mttl/models/modifiers/sm_updater.py                  | 1 -
 mttl/models/modifiers/sparse_mask.py                 | 1 -
 .../modifiers/sparse_utils/csr_add_vs_scatter_add.py | 1 -
 mttl/models/modifiers/sparse_utils/sparse_linear.py  | 3 ++-
 mttl/models/modifiers/sparse_utils/utils.py          | 7 ++++++-
 5 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mttl/models/modifiers/sm_updater.py b/mttl/models/modifiers/sm_updater.py
index d476b6951..1b003b99a 100644
--- a/mttl/models/modifiers/sm_updater.py
+++ b/mttl/models/modifiers/sm_updater.py
@@ -7,7 +7,6 @@
 import torch
 from scipy.sparse import csr_matrix
 from torch import nn
-from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut
 
 from mttl.logging import logger
 from mttl.models.modifiers.base import Modifier
diff --git a/mttl/models/modifiers/sparse_mask.py b/mttl/models/modifiers/sparse_mask.py
index cca4973ce..6e527a40b 100644
--- a/mttl/models/modifiers/sparse_mask.py
+++ b/mttl/models/modifiers/sparse_mask.py
@@ -7,7 +7,6 @@
 import torch
 from scipy.sparse import csr_matrix
 from torch import nn
-from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut
 
 from mttl.logging import logger
 from mttl.models.modifiers.base import Modifier, ModifierConfig
diff --git a/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py b/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py
index 1f6888096..1de73c0be 100644
--- a/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py
+++ b/mttl/models/modifiers/sparse_utils/csr_add_vs_scatter_add.py
@@ -5,7 +5,6 @@
 import torch.nn.functional as F
 import triton as tn
 from spops import csr_add, sddmm
-from triton.ops.blocksparse import matmul
 
 from mttl.models.modifiers.sparse_mask import SparseMaskConfig, SparseWeights
 from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights
diff --git a/mttl/models/modifiers/sparse_utils/sparse_linear.py b/mttl/models/modifiers/sparse_utils/sparse_linear.py
index 4ce38854f..8aa5669bd 100644
--- a/mttl/models/modifiers/sparse_utils/sparse_linear.py
+++ b/mttl/models/modifiers/sparse_utils/sparse_linear.py
@@ -7,7 +7,6 @@
 import torch
 from scipy.sparse import csr_matrix
 from torch import nn
-from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut
 
 from mttl.logging import logger
 from mttl.models.modifiers.base import Modifier, ModifierConfig
@@ -430,6 +429,8 @@ def __init__(
         parent_name=None,
         sparse_func=None,
     ):
+        from triton.ops.blocksparse.matmul import sdd_lut
+
         assert (
             config.sps_type == "block_sparse"
         ), "BlockSparseLinearModule only supports block_sparse type"
diff --git a/mttl/models/modifiers/sparse_utils/utils.py b/mttl/models/modifiers/sparse_utils/utils.py
index d8762bb7d..3ff0cb9ba 100644
--- a/mttl/models/modifiers/sparse_utils/utils.py
+++ b/mttl/models/modifiers/sparse_utils/utils.py
@@ -5,7 +5,6 @@
 import torch
 import torch.nn.functional as F
 from scipy.sparse import csr_matrix
-from triton.ops.blocksparse.matmul import _matmul
 
 from mttl.logging import logger
 
@@ -515,6 +514,8 @@ def backward(ctx, grad_output):
             c_lut,
             block_size,
         ) = ctx.saved_tensors
+        from triton.ops.blocksparse.matmul import _matmul
+
         weights = csr_add(
             sparse_weights, row_offs, row_idx, col_idx, dense_weights
         )  # could be done also with torch.sparse.sampled_addmm
@@ -524,6 +525,7 @@ def backward(ctx, grad_output):
         dX = grad_output @ weights
         grad_output = grad_output.contiguous()
         input = input.contiguous()
+
        dsW = _matmul.fn["sdd"](
            grad_output.unsqueeze(1),
            input.unsqueeze(1),
@@ -582,12 +584,15 @@ def backward(ctx, grad_output):
             c_lut,
             block_size,
         ) = ctx.saved_tensors
+        from triton.ops.blocksparse.matmul import _matmul
+
         weights = _scatter_add_flattened(dense_weights, sparse_weights, idxs)
         block_size = block_size.item()
         spdims = (1, weights.shape[0] // block_size, weights.shape[1] // block_size)
         dX = grad_output @ weights
         grad_output = grad_output.contiguous()
         input = input.contiguous()
+
         dsW = _matmul.fn["sdd"](
             grad_output.unsqueeze(1),
             input.unsqueeze(1),