Skip to content

Commit 6bc8121

Browse files
AmitMYSunMarcgithub-actions[bot]
authored
Fix Mac mps dataloader_num_workers > 1 causes RuntimeError: _share_filename_: only available on CPU (#38819)
* Update trainer.py: add multiprocessing_context for mps devices * Fix multiprocessing context for MPS with workers * Apply style fixes --------- Co-authored-by: Marc Sun <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 3eba206 commit 6bc8121

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/transformers/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,12 +1023,16 @@ def _get_dataloader(
10231023
else:
10241024
data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)
10251025

1026+
# MPS requrires forking if multiple workers are specified
1027+
should_fork = torch.backends.mps.is_available() and self.args.dataloader_num_workers > 1
1028+
10261029
dataloader_params = {
10271030
"batch_size": batch_size,
10281031
"collate_fn": data_collator,
10291032
"num_workers": self.args.dataloader_num_workers,
10301033
"pin_memory": self.args.dataloader_pin_memory,
10311034
"persistent_workers": self.args.dataloader_persistent_workers,
1035+
"multiprocessing_context": "fork" if should_fork else None,
10321036
}
10331037

10341038
if not isinstance(dataset, torch.utils.data.IterableDataset):

0 commit comments

Comments
 (0)