Modifiers classes related to magnitude pruning
"""
19+ import logging
1920from typing import Dict , List , Union
2021
2122import torch
2223from torch import Tensor
2324from torch .nn import Parameter
2425
25- from sparseml .pytorch .optim .modifier import PyTorchModifierYAML
26+ from sparseml .pytorch .optim .modifier import ModifierProp , PyTorchModifierYAML
2627from sparseml .pytorch .sparsification .pruning .mask_creator import (
2728 PruningMaskCreator ,
2829 get_mask_creator_default ,
4344]
4445
4546
47+ _LOGGER = logging .getLogger (__name__ )
48+
49+
4650class MagnitudePruningParamsScorer (PruningParamsScorer ):
4751 """
4852 Scores parameters based on their magnitude
@@ -106,6 +110,9 @@ class GMPruningModifier(BaseGradualPruningModifier, BaseGMPruningModifier):
106110 :param mask_type: String to define type of sparsity to apply. May be 'unstructred'
107111 for unstructured pruning or 'block4' for four block pruning or a list of two
108112 integers for a custom block shape. Default is 'unstructured'
113+ :param global_sparsity: set True to use global magnitude pruning, False for
114+ layer-wise. Default is False. [DEPRECATED] - use GlobalMagnitudePruningModifier
115+ for global magnitude pruning and MagnitudePruningModifier for layer-wise
109116 """
110117
111118 def __init__ (
@@ -120,7 +127,10 @@ def __init__(
120127 inter_func : str = "cubic" ,
121128 log_types : Union [str , List [str ]] = ALL_TOKEN ,
122129 mask_type : str = "unstructured" ,
130+ global_sparsity : bool = False ,
123131 ):
132+ self ._check_warn_global_sparsity (global_sparsity )
133+
124134 super (GMPruningModifier , self ).__init__ (
125135 params = params ,
126136 init_sparsity = init_sparsity ,
@@ -132,8 +142,8 @@ def __init__(
132142 log_types = log_types ,
133143 mask_type = mask_type ,
134144 leave_enabled = leave_enabled ,
145+ global_sparsity = global_sparsity ,
135146 end_comparator = - 1 ,
136- global_sparsity = self ._use_global_sparsity ,
137147 allow_reintroduction = False ,
138148 parent_class_kwarg_names = [
139149 "init_sparsity" ,
@@ -161,10 +171,22 @@ def _get_scorer(self, params: List[Parameter]) -> PruningParamsScorer:
161171 """
162172 return MagnitudePruningParamsScorer (params )
163173
164- @property
165- def _use_global_sparsity (self ) -> bool :
166- # base GMPruningModifier will not support global sparsity
167- return False
174+ @ModifierProp ()
175+ def global_sparsity (self ) -> bool :
176+ """
177+ :return: True for global magnitude pruning, False for
178+ layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
179+ for global magnitude pruning and MagnitudePruningModifier for layer-wise
180+ """
181+ return self ._global_sparsity
182+
183+ def _check_warn_global_sparsity (self , global_sparsity ):
184+ if self .__class__ .__name__ == "GMPruningModifier" and global_sparsity is True :
185+ _LOGGER .warning (
186+ "Use of global_sparsity parameter in GMPruningModifier is now "
187+ "deprecated. Use GlobalMagnitudePruningModifier instead for global "
188+ "magnitude pruning"
189+ )
168190
169191
170192@PyTorchModifierYAML ()
@@ -217,8 +239,41 @@ class MagnitudePruningModifier(GMPruningModifier):
217239 integers for a custom block shape. Default is 'unstructured'
218240 """
219241
220- # just an alias for GMPruningModifier
221- pass
242+ def __init__ (
243+ self ,
244+ init_sparsity : Union [float , str ],
245+ final_sparsity : Union [float , Dict [float , List [str ]]],
246+ start_epoch : float ,
247+ end_epoch : float ,
248+ update_frequency : float ,
249+ params : Union [str , List [str ]],
250+ leave_enabled : bool = True ,
251+ inter_func : str = "cubic" ,
252+ log_types : Union [str , List [str ]] = ALL_TOKEN ,
253+ mask_type : str = "unstructured" ,
254+ ):
255+ super (MagnitudePruningModifier , self ).__init__ (
256+ params = params ,
257+ init_sparsity = init_sparsity ,
258+ final_sparsity = final_sparsity ,
259+ start_epoch = start_epoch ,
260+ end_epoch = end_epoch ,
261+ update_frequency = update_frequency ,
262+ inter_func = inter_func ,
263+ log_types = log_types ,
264+ mask_type = mask_type ,
265+ leave_enabled = leave_enabled ,
266+ global_sparsity = False ,
267+ )
268+
269+ @ModifierProp (serializable = False )
270+ def global_sparsity (self ) -> bool :
271+ """
272+ :return: True for global magnitude pruning, False for
273+ layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
274+ for global magnitude pruning and MagnitudePruningModifier for layer-wise
275+ """
276+ return self ._global_sparsity
222277
223278
224279@PyTorchModifierYAML ()
@@ -295,8 +350,14 @@ def __init__(
295350 log_types = log_types ,
296351 mask_type = mask_type ,
297352 leave_enabled = leave_enabled ,
353+ global_sparsity = True ,
298354 )
299355
300- @property
301- def _use_global_sparsity (self ) -> bool :
302- return True
356+ @ModifierProp (serializable = False )
357+ def global_sparsity (self ) -> bool :
358+ """
359+ :return: True for global magnitude pruning, False for
360+ layer-wise. [DEPRECATED] - use GlobalMagnitudePruningModifier
361+ for global magnitude pruning and MagnitudePruningModifier for layer-wise
362+ """
363+ return self ._global_sparsity