Commit 59dfc1d

Fix device_map computation part 2 (#42290)

fix

1 parent 4391cfd · commit 59dfc1d

File tree: 1 file changed, +10 −2 lines changed

src/transformers/integrations/accelerate.py

Lines changed: 10 additions & 2 deletions
@@ -204,13 +204,18 @@ def check_and_set_device_map(device_map: "torch.device | int | str | dict | None


 def compute_module_sizes(
-    model: "PreTrainedModel", hf_quantizer: "HfQuantizer | None" = None, buffers_only: bool = False
+    model: "PreTrainedModel",
+    hf_quantizer: "HfQuantizer | None" = None,
+    buffers_only: bool = False,
+    only_modules: bool = True,
 ) -> tuple[dict[str, int], dict[str, int]]:
     """
     Compute the size of each submodule of a given model (in bytes).
     Returns a tuple of 2 dicts, the first one containing a mapping of all the modules and the corresponding size
     in bytes, and the 2nd one containing a mapping from all leaf modules (modules containing parameters, the end of
     the model graph) and the corresponding sizes.
+    If `only_modules` is set to False, the first mapping will not only contain the size of all modules, but also
+    the size of all parameters and buffers.
     """
     all_module_sizes = defaultdict(int)
     leaves_module_sizes = defaultdict(int)
@@ -241,6 +246,9 @@ def all_tensors():
             all_module_sizes[".".join(name_parts[:idx])] += size
         if "." in name:
             leaves_module_sizes[name.rsplit(".", 1)[0]] += size
+        # If we want to also have the full leaves in `all_module_sizes`
+        if not only_modules:
+            all_module_sizes[name] += size

     return all_module_sizes, leaves_module_sizes

@@ -542,7 +550,7 @@ def _init_infer_auto_device_map(
     else:
         main_devices = ["cpu"]

-    module_sizes, _ = compute_module_sizes(model, hf_quantizer)
+    module_sizes, _ = compute_module_sizes(model, hf_quantizer, only_modules=False)

     if tied_parameters is None:
         if len(model.all_tied_weights_keys) > 0:
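
For readers skimming the diff, here is a minimal standalone sketch of the aggregation logic this commit changes. `toy_module_sizes` is a hypothetical simplification: the real `compute_module_sizes` also handles `hf_quantizer` and `buffers_only`, which are omitted here.

    # Hypothetical, simplified sketch of the size-aggregation logic
    # touched by this commit; not the actual transformers implementation.
    from collections import defaultdict

    import torch.nn as nn


    def toy_module_sizes(model: nn.Module, only_modules: bool = True):
        all_module_sizes = defaultdict(int)
        leaves_module_sizes = defaultdict(int)
        # state_dict() yields both parameters and buffers.
        for name, tensor in model.state_dict().items():
            size = tensor.numel() * tensor.element_size()
            name_parts = name.split(".")
            # Accumulate the tensor size into every parent module prefix
            # ("" is the whole model, "0" the first submodule, etc.).
            for idx in range(len(name_parts)):
                all_module_sizes[".".join(name_parts[:idx])] += size
            if "." in name:
                leaves_module_sizes[name.rsplit(".", 1)[0]] += size
            # The new behavior: with only_modules=False, the full
            # parameter/buffer names are also recorded in the first mapping.
            if not only_modules:
                all_module_sizes[name] += size
        return all_module_sizes, leaves_module_sizes


    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    sizes, leaves = toy_module_sizes(model, only_modules=False)
    print(sizes["0"])         # size of the first Linear (weight + bias)
    print(sizes["0.weight"])  # present only because only_modules=False

With `only_modules=False`, the call site in `_init_infer_auto_device_map` can look up the size of an individual weight or buffer by its full name, not just whole submodules, which appears to be the point of the fix.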
