Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 7fe4554

Browse files
authored
[cherry-pick] 0.11.1 patch - global_sparsity backwards compatibility (#627)
1 parent 45e8c28 commit 7fe4554

File tree

5 files changed

+100
-13
lines changed

5 files changed

+100
-13
lines changed

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_magnitude.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
Modifiers classes related to magnitude pruning
1717
"""
1818

19+
import logging
1920
from typing import Dict, List, Union
2021

2122
import torch
2223
from torch import Tensor
2324
from torch.nn import Parameter
2425

25-
from sparseml.pytorch.optim.modifier import PyTorchModifierYAML
26+
from sparseml.pytorch.optim.modifier import ModifierProp, PyTorchModifierYAML
2627
from sparseml.pytorch.sparsification.pruning.mask_creator import (
2728
PruningMaskCreator,
2829
get_mask_creator_default,
@@ -43,6 +44,9 @@
4344
]
4445

4546

47+
_LOGGER = logging.getLogger(__name__)
48+
49+
4650
class MagnitudePruningParamsScorer(PruningParamsScorer):
4751
"""
4852
Scores parameters based on their magnitude
@@ -106,6 +110,9 @@ class GMPruningModifier(BaseGradualPruningModifier, BaseGMPruningModifier):
106110
:param mask_type: String to define type of sparsity to apply. May be 'unstructred'
107111
for unstructured pruning or 'block4' for four block pruning or a list of two
108112
integers for a custom block shape. Default is 'unstructured'
113+
:param global_sparsity: set True to use global magnitude pruning, False for
114+
layer-wise. Default is False. [DEPRECATED] - use GlobalMagnitudePruningModifier
115+
for global magnitude pruning and MagnitudePruningModifier for layer-wise
109116
"""
110117

111118
def __init__(
@@ -120,7 +127,10 @@ def __init__(
120127
inter_func: str = "cubic",
121128
log_types: Union[str, List[str]] = ALL_TOKEN,
122129
mask_type: str = "unstructured",
130+
global_sparsity: bool = False,
123131
):
132+
self._check_warn_global_sparsity(global_sparsity)
133+
124134
super(GMPruningModifier, self).__init__(
125135
params=params,
126136
init_sparsity=init_sparsity,
@@ -132,8 +142,8 @@ def __init__(
132142
log_types=log_types,
133143
mask_type=mask_type,
134144
leave_enabled=leave_enabled,
145+
global_sparsity=global_sparsity,
135146
end_comparator=-1,
136-
global_sparsity=self._use_global_sparsity,
137147
allow_reintroduction=False,
138148
parent_class_kwarg_names=[
139149
"init_sparsity",
@@ -161,10 +171,22 @@ def _get_scorer(self, params: List[Parameter]) -> PruningParamsScorer:
161171
"""
162172
return MagnitudePruningParamsScorer(params)
163173

164-
@property
165-
def _use_global_sparsity(self) -> bool:
166-
# base GMPruningModifier will not support global sparsity
167-
return False
174+
@ModifierProp()
175+
def global_sparsity(self) -> bool:
176+
"""
177+
:return: True for global magnitude pruning, False for
178+
layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
179+
for global magnitude pruning and MagnitudePruningModifier for layer-wise
180+
"""
181+
return self._global_sparsity
182+
183+
def _check_warn_global_sparsity(self, global_sparsity):
184+
if self.__class__.__name__ == "GMPruningModifier" and global_sparsity is True:
185+
_LOGGER.warning(
186+
"Use of global_sparsity parameter in GMPruningModifier is now "
187+
"deprecated. Use GlobalMagnitudePruningModifier instead for global "
188+
"magnitude pruning"
189+
)
168190

169191

170192
@PyTorchModifierYAML()
@@ -217,8 +239,41 @@ class MagnitudePruningModifier(GMPruningModifier):
217239
integers for a custom block shape. Default is 'unstructured'
218240
"""
219241

220-
# just an alias for GMPruningModifier
221-
pass
242+
def __init__(
243+
self,
244+
init_sparsity: Union[float, str],
245+
final_sparsity: Union[float, Dict[float, List[str]]],
246+
start_epoch: float,
247+
end_epoch: float,
248+
update_frequency: float,
249+
params: Union[str, List[str]],
250+
leave_enabled: bool = True,
251+
inter_func: str = "cubic",
252+
log_types: Union[str, List[str]] = ALL_TOKEN,
253+
mask_type: str = "unstructured",
254+
):
255+
super(MagnitudePruningModifier, self).__init__(
256+
params=params,
257+
init_sparsity=init_sparsity,
258+
final_sparsity=final_sparsity,
259+
start_epoch=start_epoch,
260+
end_epoch=end_epoch,
261+
update_frequency=update_frequency,
262+
inter_func=inter_func,
263+
log_types=log_types,
264+
mask_type=mask_type,
265+
leave_enabled=leave_enabled,
266+
global_sparsity=False,
267+
)
268+
269+
@ModifierProp(serializable=False)
270+
def global_sparsity(self) -> bool:
271+
"""
272+
:return: True for global magnitude pruning, False for
273+
layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
274+
for global magnitude pruning and MagnitudePruningModifier for layer-wise
275+
"""
276+
return self._global_sparsity
222277

223278

224279
@PyTorchModifierYAML()
@@ -295,8 +350,14 @@ def __init__(
295350
log_types=log_types,
296351
mask_type=mask_type,
297352
leave_enabled=leave_enabled,
353+
global_sparsity=True,
298354
)
299355

300-
@property
301-
def _use_global_sparsity(self) -> bool:
302-
return True
356+
@ModifierProp(serializable=False)
357+
def global_sparsity(self) -> bool:
358+
"""
359+
:return: True for global magnitude pruning, False for
360+
layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
361+
for global magnitude pruning and MagnitudePruningModifier for layer-wise
362+
"""
363+
return self._global_sparsity

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_movement.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from torch import Tensor
2525
from torch.nn import Parameter
2626

27-
from sparseml.pytorch.optim.modifier import PyTorchModifierYAML
27+
from sparseml.pytorch.optim.modifier import ModifierProp, PyTorchModifierYAML
2828
from sparseml.pytorch.sparsification.pruning.modifier_pruning_magnitude import (
2929
GMPruningModifier,
3030
)
@@ -122,6 +122,15 @@ def _get_scorer(self, params: List[Parameter]) -> PruningParamsGradScorer:
122122
"""
123123
return MovementPruningParamsScorer(params=params)
124124

125+
@ModifierProp(serializable=False)
126+
def global_sparsity(self) -> bool:
127+
"""
128+
:return: True for global magnitude pruning, False for
129+
layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
130+
for global magnitude pruning and MagnitudePruningModifier for layer-wise
131+
"""
132+
return self._global_sparsity
133+
125134

126135
class MovementPruningParamsScorer(PruningParamsGradScorer):
127136
"""

src/sparseml/pytorch/sparsification/pruning/modifier_pruning_structured.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,12 @@ def param_groups(self) -> List[List[str]]:
426426
useful for structures such as residual blocks or grouped convolutions
427427
"""
428428
return self._param_groups
429+
430+
@ModifierProp(serializable=False)
431+
def global_sparsity(self) -> bool:
432+
"""
433+
:return: True for global magnitude pruning, False for
434+
layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
435+
for global magnitude pruning and MagnitudePruningModifier for layer-wise
436+
"""
437+
return self._global_sparsity

src/sparseml/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datetime import date
2020

2121

22-
version_base = "0.11.0"
22+
version_base = "0.11.1"
2323
is_release = False # change to True to set the generated version as a release version
2424

2525

tests/sparseml/pytorch/sparsification/pruning/test_modifier_pruning_magnitude.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def test_gm_pruning_yaml(params, init_sparsity, final_sparsity):
249249
update_frequency = 1.0
250250
inter_func = "cubic"
251251
mask_type = "filter"
252+
global_sparsity = False
252253
yaml_str = f"""
253254
!GMPruningModifier
254255
init_sparsity: {init_sparsity}
@@ -259,6 +260,7 @@ def test_gm_pruning_yaml(params, init_sparsity, final_sparsity):
259260
params: {params}
260261
inter_func: {inter_func}
261262
mask_type: {mask_type}
263+
global_sparsity: {global_sparsity}
262264
"""
263265
yaml_modifier = GMPruningModifier.load_obj(yaml_str) # type: GMPruningModifier
264266
serialized_modifier = GMPruningModifier.load_obj(
@@ -273,12 +275,18 @@ def test_gm_pruning_yaml(params, init_sparsity, final_sparsity):
273275
params=params,
274276
inter_func=inter_func,
275277
mask_type=mask_type,
278+
global_sparsity=global_sparsity,
276279
)
277280

278281
assert isinstance(yaml_modifier, GMPruningModifier)
279282
pruning_modifier_serialization_vals_test(
280283
yaml_modifier, serialized_modifier, obj_modifier
281284
)
285+
assert (
286+
yaml_modifier.global_sparsity
287+
== serialized_modifier.global_sparsity
288+
== obj_modifier.global_sparsity
289+
)
282290

283291

284292
def test_magnitude_pruning_yaml():

0 commit comments

Comments
 (0)