@@ -204,13 +204,18 @@ def check_and_set_device_map(device_map: "torch.device | int | str | dict | None
204204
205205
206206def compute_module_sizes (
207- model : "PreTrainedModel" , hf_quantizer : "HfQuantizer | None" = None , buffers_only : bool = False
207+ model : "PreTrainedModel" ,
208+ hf_quantizer : "HfQuantizer | None" = None ,
209+ buffers_only : bool = False ,
210+ only_modules : bool = True ,
208211) -> tuple [dict [str , int ], dict [str , int ]]:
209212 """
210213 Compute the size of each submodule of a given model (in bytes).
211214 Returns a tuple of 2 dicts, the first one containing a mapping of all the modules and the corresponding size
212215 in bytes, and the 2nd one containing a mapping from all leaf modules (modules containing parameters, the end of
213216 the model graph) and the corresponding sizes.
217+ If `only_modules` is set to False, the first mapping will not only contain the size of all modules, but also
218+ the size of all parameters and buffers.
214219 """
215220 all_module_sizes = defaultdict (int )
216221 leaves_module_sizes = defaultdict (int )
@@ -241,6 +246,9 @@ def all_tensors():
241246 all_module_sizes ["." .join (name_parts [:idx ])] += size
242247 if "." in name :
243248 leaves_module_sizes [name .rsplit ("." , 1 )[0 ]] += size
249+ # If we want to also have the full leaves in `all_module_sizes`
250+ if not only_modules :
251+ all_module_sizes [name ] += size
244252
245253 return all_module_sizes , leaves_module_sizes
246254
@@ -542,7 +550,7 @@ def _init_infer_auto_device_map(
542550 else :
543551 main_devices = ["cpu" ]
544552
545- module_sizes , _ = compute_module_sizes (model , hf_quantizer )
553+ module_sizes , _ = compute_module_sizes (model , hf_quantizer , only_modules = False )
546554
547555 if tied_parameters is None :
548556 if len (model .all_tied_weights_keys ) > 0 :
0 commit comments