Commit 9154566
[tests] model-level device_map clarifications (#11681)
* add clarity in documentation for device_map
* docs
* fix how compiler tester mixins are used.
* propagate
* more
* typo.
* fix tests
* fix order of decorators.
* clarify more.
* more test cases.
* fix doc
* fix device_map docstring in pipeline_utils.
* more examples
* more
* update
* remove code for stuff that is already supported.
* fix stuff.
1 parent b6f7933 commit 9154566

File tree

3 files changed: +74 −11 lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 32 additions & 3 deletions
@@ -814,14 +814,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                 information.
-            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+            device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                 A map that specifies where each submodule should go. It doesn't need to be defined for each
                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                 same device. Defaults to `None`, meaning that the model will be loaded on CPU.
 
+                Examples:
+
+                ```py
+                >>> from diffusers import AutoModel
+                >>> import torch
+
+                >>> # This works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+                ... )
+                >>> # This also works (integer accelerator device ID).
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+                ... )
+                >>> # Specifying a supported offloading strategy like "auto" also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+                ... )
+                >>> # Specifying a dictionary as `device_map` also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0",
+                ...     subfolder="unet",
+                ...     device_map={"": torch.device("cuda")},
+                ... )
+                ```
+
                 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
-                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+                map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
+                can also refer to the [Diffusers-specific
+                documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+                for more concrete examples.
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                 each GPU and the available CPU RAM if unset.

@@ -1387,7 +1416,7 @@ def _load_pretrained_model(
         low_cpu_mem_usage: bool = True,
         dtype: Optional[Union[str, torch.dtype]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
-        device_map: Dict[str, Union[int, str, torch.device]] = None,
+        device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
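The new docstring stops at `device_map={"": torch.device("cuda")}`, but the "auto" form composes with `max_memory`, which is documented immediately below it. A minimal sketch (not part of the commit), assuming a single visible GPU; the memory caps are illustrative values, not recommendations:

```py
import torch
from diffusers import AutoModel

# Let 🤗 Accelerate compute the placement, but cap how much memory it may
# use per device; whatever doesn't fit on GPU 0 spills over to CPU RAM.
model = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    device_map="auto",
    max_memory={0: "6GiB", "cpu": "16GiB"},
    torch_dtype=torch.float16,
)

# Accelerate-dispatched models expose the resolved per-submodule placement.
print(model.hf_device_map)
```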

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 5 additions & 8 deletions
@@ -669,14 +669,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                 information.
-            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
-                A map that specifies where each submodule should go. It doesn't need to be defined for each
-                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
-                same device.
-
-                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
-                more information about each option see [designing a device
-                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            device_map (`str`, *optional*):
+                Strategy that dictates how the different components of a pipeline should be placed on available
+                devices. Currently, only "balanced" `device_map` is supported. Check out
+                [this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
+                to know more.
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                 each GPU and the available CPU RAM if unset.
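Note the asymmetry this hunk documents: at the pipeline level `device_map` is only a strategy string, while the int/str/`torch.device`/dict forms above apply to individual models. A minimal sketch (not part of the commit), assuming a machine with at least two visible GPUs:

```py
import torch
from diffusers import DiffusionPipeline

# "balanced" places whole components (UNet, text encoders, VAE) across
# the available GPUs; it does not split a single model across devices.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    device_map="balanced",
    torch_dtype=torch.float16,
)

# The resolved component-to-device assignment, e.g. {"unet": 0, "vae": 1, ...}.
print(pipeline.hf_device_map)
```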

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 37 additions & 0 deletions
@@ -46,6 +46,7 @@
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
+    require_torch_gpu,
     skip_mps,
     slow,
     torch_all_close,

@@ -1083,6 +1084,42 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
+    @parameterized.expand(
+        [
+            (-1, "You can't pass device_map as a negative int"),
+            ("foo", "When passing device_map as a string, the value needs to be a device name"),
+        ]
+    )
+    def test_wrong_device_map_raises_error(self, device_map, msg_substring):
+        with self.assertRaises(ValueError) as err_ctx:
+            _ = self.model_class.from_pretrained(
+                "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
+            )
+
+        assert msg_substring in str(err_ctx.exception)
+
+    @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
+    @require_torch_gpu
+    def test_passing_non_dict_device_map_works(self, device_map):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
+    @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
+    @require_torch_gpu
+    def test_passing_dict_device_map_works(self, name, device_map):
+        # There are other valid dict-based `device_map` values too. It's best to refer to
+        # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
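The error-path test above carries no `@require_torch_gpu` decorator, so the new validation evidently runs before any device placement and is exercisable on CPU. A minimal sketch of the behavior it pins down (not part of the commit):

```py
from diffusers import UNet2DConditionModel

# Per the new tests, an invalid device_map raises ValueError early,
# instead of surfacing a confusing error deep inside Accelerate.
try:
    UNet2DConditionModel.from_pretrained(
        "hf-internal-testing/unet2d-sharded-dummy-subfolder",
        subfolder="unet",
        device_map=-1,  # negative accelerator IDs are rejected
    )
except ValueError as err:
    print(err)  # contains "You can't pass device_map as a negative int"
```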
