
Commit 98180c6
Merge pull request #135 from microsoft/req-updated
update requirements
sordonia authored Nov 12, 2024
2 parents a7ceca5 + 7d3cffb commit 98180c6
Showing 7 changed files with 58 additions and 6 deletions.
mttl/models/modifiers/sm_updater.py (0 additions, 1 deletion)
@@ -7,7 +7,6 @@
 import torch
 from scipy.sparse import csr_matrix
 from torch import nn
-from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut

 from mttl.logging import logger
 from mttl.models.modifiers.base import Modifier

mttl/models/modifiers/sparse_mask.py (0 additions, 1 deletion)
@@ -7,7 +7,6 @@
 import torch
 from scipy.sparse import csr_matrix
 from torch import nn
-from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut

 from mttl.logging import logger
 from mttl.models.modifiers.base import Modifier, ModifierConfig

(file name not shown in this copy) (0 additions, 1 deletion)
@@ -5,7 +5,6 @@
 import torch.nn.functional as F
 import triton as tn
 from spops import csr_add, sddmm
-from triton.ops.blocksparse import matmul

 from mttl.models.modifiers.sparse_mask import SparseMaskConfig, SparseWeights
 from mttl.models.modifiers.sparse_utils.utils import init_sparse_weights

mttl/models/modifiers/sparse_utils/sparse_linear.py (2 additions, 1 deletion)
@@ -7,7 +7,6 @@
 import torch
 from scipy.sparse import csr_matrix
 from torch import nn
-from triton.ops.blocksparse.matmul import dsd_lut, sdd_lut

 from mttl.logging import logger
 from mttl.models.modifiers.base import Modifier, ModifierConfig
@@ -430,6 +429,8 @@ def __init__(
         parent_name=None,
         sparse_func=None,
     ):
+        from triton.ops.blocksparse.matmul import sdd_lut
+
         assert (
             config.sps_type == "block_sparse"
         ), "BlockSparseLinearModule only supports block_sparse type"

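Note: the second hunk above moves the triton import from module scope into the constructor, so the module itself can be imported even when triton.ops.blocksparse is not importable; the import only runs (and can only fail) when a block-sparse layer is actually built. A minimal sketch of that deferred-import pattern, with an illustrative class name, arguments, and body (only the sdd_lut import comes from the diff):

import torch
from torch import nn


class BlockSparseLinearSketch(nn.Module):
    """Illustrative layer using the deferred-import pattern from the hunk above."""

    def __init__(self, in_features: int, out_features: int, block_size: int = 16):
        super().__init__()
        # Deferred import: a missing or incompatible triton install fails here,
        # at construction time, rather than when the defining module is imported.
        from triton.ops.blocksparse.matmul import sdd_lut  # noqa: F401

        self.block_size = block_size
        self.weight = nn.Parameter(torch.zeros(out_features, in_features))
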
mttl/models/modifiers/sparse_utils/utils.py (6 additions, 1 deletion)
@@ -5,7 +5,6 @@
 import torch
 import torch.nn.functional as F
 from scipy.sparse import csr_matrix
-from triton.ops.blocksparse.matmul import _matmul

 from mttl.logging import logger

@@ -515,6 +514,8 @@ def backward(ctx, grad_output):
             c_lut,
             block_size,
         ) = ctx.saved_tensors
+        from triton.ops.blocksparse.matmul import _matmul
+
         weights = csr_add(
             sparse_weights, row_offs, row_idx, col_idx, dense_weights
         ) # could be done also with torch.sparse.sampled_addmm
@@ -524,6 +525,7 @@
         dX = grad_output @ weights
         grad_output = grad_output.contiguous()
         input = input.contiguous()
+
         dsW = _matmul.fn["sdd"](
             grad_output.unsqueeze(1),
             input.unsqueeze(1),
@@ -582,12 +584,15 @@ def backward(ctx, grad_output):
             c_lut,
             block_size,
         ) = ctx.saved_tensors
+        from triton.ops.blocksparse.matmul import _matmul
+
         weights = _scatter_add_flattened(dense_weights, sparse_weights, idxs)
         block_size = block_size.item()
         spdims = (1, weights.shape[0] // block_size, weights.shape[1] // block_size)
         dX = grad_output @ weights
         grad_output = grad_output.contiguous()
         input = input.contiguous()
+
         dsW = _matmul.fn["sdd"](
             grad_output.unsqueeze(1),
             input.unsqueeze(1),

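Note: in both backward hunks, _matmul.fn["sdd"] is triton's block-sparse matmul in "sdd" mode, which produces a block-sparse output from two dense operands, i.e. it computes only the weight-gradient blocks that belong to the sparse pattern. A dense stand-in for intuition only, with illustrative names and shapes (the repository uses the triton kernel, not this masked matmul):

import torch


def sdd_gradient_dense(grad_output, inputs, block_mask, block_size):
    # grad_output: (tokens, out_features); inputs: (tokens, in_features)
    # block_mask: (out_features // block_size, in_features // block_size) of 0/1
    dense_grad = grad_output.t() @ inputs  # full (out_features, in_features) gradient
    mask = block_mask.repeat_interleave(block_size, dim=0)
    mask = mask.repeat_interleave(block_size, dim=1)
    return dense_grad * mask  # keep only the blocks inside the sparse pattern


# made-up sizes for a quick check: 8x16 weights tiled into 4x4 blocks
g = torch.randn(32, 8)
x = torch.randn(32, 16)
blocks = (torch.rand(2, 4) > 0.5).float()
dsW_dense = sdd_gradient_dense(g, x, blocks, block_size=4)
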
requirements.txt (2 additions, 1 deletion)
@@ -4,7 +4,7 @@ datasets>=2.20.0
 pytorch-lightning>=2.3.3
 accelerate
 deepspeed
-huggingface_hub
+huggingface_hub>=0.26.2
 click
 wandb
 rouge
@@ -24,5 +24,6 @@ seaborn
 azure-storage-blob
 azure-identity
 einops
+triton
 nltk
 # spops @ git+https://github.com/IST-DASLab/spops.git@main

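Note: the new floor on huggingface_hub and the explicit triton entry can be checked against a live environment. A small illustrative sketch; only the package names and the 0.26.2 bound come from requirements.txt above:

from importlib.metadata import version

import triton  # listed explicitly now; its blocksparse helpers are imported lazily above
from packaging.version import Version

# fail early if the installed huggingface_hub is older than the pinned floor
assert Version(version("huggingface_hub")) >= Version("0.26.2")
print("huggingface_hub", version("huggingface_hub"), "| triton", triton.__version__)
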
tests/test_expert_model.py (48 additions, 0 deletions)
@@ -172,6 +172,54 @@ def test_from_pretrained_multi_selector(tmp_path):
     assert model.selector_config.get("lora").__class__ == UniformSelectorConfig


+def test_from_pretrained_with_arrow_save_and_reload(tmp_path):
+    # create a dummy library
+    model = MultiExpertModel(MultiExpertModelConfig("EleutherAI/gpt-neo-125m"))
+    model.add_empty_expert(
+        "a", LoRAConfig(modify_layers=".*out_proj", lora_init_b_random=True)
+    )
+    model.add_empty_expert(
+        "b", LoRAConfig(modify_layers=".*out_proj", lora_init_b_random=True)
+    )
+    library = model.save_to_library(f"local://{tmp_path}")
+
+    # store arrow experts
+    protos = ArrowTransform(ArrowConfig()).transform(library, persist=True)
+
+    # from pretrained library
+    selector_config = ArrowSelectorConfig(top_k=4)
+    model = MultiExpertModel.from_pretrained_library(
+        library, selector_config=selector_config
+    )
+    assert len(model.experts_names) == 2
+    # the order might be different due to multi-threading in adding experts in parallel
+    assert "a" in model.experts_names
+    assert "b" in model.experts_names
+
+    selector = model.selectors["lora"][0]
+    assert selector.config == selector_config
+    assert isinstance(selector, ArrowSelector)
+    model.save_pretrained(tmp_path)
+
+    # reload from the checkpoint
+    model = MultiExpertModel.from_pretrained(tmp_path)
+    selector = model.selectors["lora"][0]
+    assert selector.prototypes.shape[0] == 2
+    name1 = selector.expert_names[0]
+    name2 = selector.expert_names[1]
+    ln = selector.layer_name.replace(".selector", "")
+
+    assert np.allclose(
+        selector.prototypes[0].sum().item(),
+        protos[name1][ln].sum().item(),
+    )
+
+    assert np.allclose(
+        selector.prototypes[1].sum().item(),
+        protos[name2][ln].sum().item(),
+    )
+
+
 def test_from_pretrained_with_arrow(tmp_path):
     # create a dummy library
     model = MultiExpertModel(MultiExpertModelConfig("EleutherAI/gpt-neo-125m"))
