Skip to content

Commit 5aae81b

Browse files
authored
initial commit (#2223)
1 parent 88196d5 commit 5aae81b

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

src/sparseml/transformers/sparsification/modification/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
__all__ = ["check_transformers_version"]
2525

2626
_TRANSFORMERS_MIN_VERSION = "4.39.0"
27-
_TRANSFORMERS_MAX_VERSION = "4.39.2"
27+
_TRANSFORMERS_MAX_VERSION = "4.39.3"
2828

2929

3030
def check_transformers_version(
@@ -56,7 +56,7 @@ def check_transformers_version(
5656
_LOGGER.warning(
5757
"Attempting to modify the transformers model to support "
5858
"the SparseML-specific functionalities. However, the detected "
59-
f"transformers version ({current_version}) does not fall within the"
59+
f"transformers version ({current_version}) does not fall within the "
6060
f"supported version range ({min_version} - {max_version}). "
6161
"This may lead to unexpected behavior. Please ensure that the "
6262
"correct transformers version is installed."
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import shutil
16+
17+
import pytest
18+
19+
import sparseml.core.session as session_manager
20+
from huggingface_hub import snapshot_download
21+
from sparseml.transformers import SparseAutoModelForCausalLM
22+
23+
24+
@pytest.fixture
def model_path(tmp_path):
    """Yield a local path to a tiny random llama-2 checkpoint.

    The snapshot is downloaded into pytest's per-test ``tmp_path`` and the
    whole directory is removed once the dependent test finishes.
    """
    local_path = snapshot_download("stas/tiny-random-llama-2", local_dir=tmp_path)
    yield local_path
    # teardown: drop the downloaded snapshot so temp storage is not leaked
    shutil.rmtree(tmp_path)
29+
30+
@pytest.fixture
def recipe():
    """Return a quantization recipe template as a YAML string.

    The ``{silu_activation}`` placeholder is filled in by the test via
    ``str.format`` so the same recipe can target either the ``SiLU`` or the
    ``SiLUActivation`` alias of the activation module.

    NOTE(review): the YAML body's leading whitespace was reconstructed from a
    diff view where indentation was lost — confirm it matches the original
    file before relying on exact byte equality.
    """
    recipe_template = """test_stage:
    obcq_modifiers:
        QuantizationModifier:
            ignore:
                - LlamaRotaryEmbedding
                - LlamaRMSNorm
                - {silu_activation}
            scheme_overrides:
                Embedding:
                    input_activations: null
                    weights:
                        num_bits: 8
                        symmetric: false"""
    return recipe_template
45+
46+
47+
def test_silu_alias_same_output(recipe, model_path):
    """Quantizing via the ``SiLU`` alias and via ``SiLUActivation`` must agree.

    Loads the same checkpoint twice — once with each spelling of the SiLU
    activation in the quantization recipe's ``ignore`` list — and checks that
    both models produce the same logits on the model's dummy inputs.

    :param recipe: YAML recipe template with a ``{silu_activation}`` placeholder
    :param model_path: local path to a tiny llama-2 snapshot
    """
    model_ = SparseAutoModelForCausalLM.from_pretrained(
        model_path, recipe=recipe.format(silu_activation="SiLU")
    )
    # reset the global sparseml session so the second load starts from a
    # clean state instead of stacking modifiers onto the first one
    session_manager.create_session()
    session_manager.active_session().reset()
    model = SparseAutoModelForCausalLM.from_pretrained(
        model_path, recipe=recipe.format(silu_activation="SiLUActivation")
    )

    dummy_input = model.dummy_inputs

    out = model(**dummy_input)
    out_ = model_(**dummy_input)

    # BUG FIX: the original called allclose but discarded its boolean result,
    # so the test could never fail; assert on it so mismatches are reported
    assert out.logits.allclose(out_.logits)

0 commit comments

Comments
 (0)