Skip to content

Commit 990d654

Browse files
committed
expose common dataloader args
1 parent cbdb311 commit 990d654

File tree

6 files changed

+121
-13
lines changed

6 files changed

+121
-13
lines changed

tests/unit_tests/test_dataset_flux.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ def test_load_dataset(self):
4545
for world_size in [2]:
4646
for rank in range(world_size):
4747
dataset_name = "cc12m-test-iterable"
48-
batch_size = 1
4948

49+
batch_size = 1
5050
num_steps = 15
51+
num_workers = 4
5152

5253
# TODO: if num_steps * batch_size * world_size is larger than the number of samples
5354
# in the dataset, then the test will fail, due to huggingface's
@@ -64,6 +65,8 @@ def test_load_dataset(self):
6465
dataset_name,
6566
"--training.local_batch_size",
6667
str(batch_size),
68+
"--training.dataloader.num_workers",
69+
str(num_workers),
6770
"--training.classifier_free_guidance_prob",
6871
"0.447",
6972
"--training.test_mode",
@@ -82,6 +85,8 @@ def test_load_dataset(self):
8285
infinite=True,
8386
)
8487

88+
assert dl.num_workers == num_workers
89+
8590
it = iter(dl)
8691

8792
for i in range(0, num_steps):

torchtitan/components/dataloader.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
5353
dp_rank: Data parallelism rank for this dataloader.
5454
dp_world_size: The world size of the data parallelism.
5555
batch_size: The batch size to use for each iteration.
56-
collate_fn: Optional function to collate samples in a batch.
56+
collate_fn (Callable, optional): A function that takes a list of samples from the
57+
dataset and collates them into a batch. Defaults to ``None``.
58+
num_workers: Number of worker processes for data loading. Defaults to 0.
59+
persistent_workers: If True, keep workers alive between dataset iterations.
60+
Only applicable when num_workers > 0. Defaults to False.
61+
prefetch_factor: Number of batches to prefetch per worker. Only applicable
62+
when num_workers > 0. Defaults to None (uses PyTorch default of 2).
63+
pin_memory: If True, copy tensors to CUDA pinned memory. Defaults to False.
5764
"""
5865

5966
dp_rank: int
@@ -67,11 +74,23 @@ def __init__(
6774
dp_world_size: int,
6875
batch_size: int,
6976
collate_fn: Callable | None = None,
77+
num_workers: int = 0,
78+
persistent_workers: bool = False,
79+
prefetch_factor: int | None = None,
80+
pin_memory: bool = False,
7081
):
7182
self.dp_world_size = dp_world_size
7283
self.dp_rank = dp_rank
7384
self.batch_size = batch_size
74-
super().__init__(dataset, batch_size, collate_fn=collate_fn)
85+
super().__init__(
86+
dataset,
87+
batch_size,
88+
collate_fn=collate_fn,
89+
num_workers=num_workers,
90+
persistent_workers=persistent_workers,
91+
prefetch_factor=prefetch_factor,
92+
pin_memory=pin_memory,
93+
)
7594
self._rank_id = f"dp_rank_{dp_rank}"
7695

7796
def state_dict(self) -> dict[str, Any]:

torchtitan/config/job_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,31 @@ class LRScheduler:
198198
"""
199199

200200

201+
@dataclass
202+
class DataLoader:
203+
"""
204+
Configuration for PyTorch DataLoader settings.
205+
"""
206+
207+
num_workers: int = 0
208+
"""Number of worker processes for data loading. 0 means data will be loaded in the main process."""
209+
210+
persistent_workers: bool = False
211+
"""
212+
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
213+
This keeps the workers' Dataset instances alive. Only applicable when num_workers > 0.
214+
"""
215+
216+
prefetch_factor: int | None = None
217+
"""
218+
Number of batches loaded in advance by each worker. If None, the default value (2) is used.
219+
Only applicable when num_workers > 0.
220+
"""
221+
222+
pin_memory: bool = False
223+
"""If True, the data loader will copy Tensors into CUDA pinned memory before returning them."""
224+
225+
201226
@dataclass
202227
class Training:
203228
dataset: str = "c4_test"
@@ -263,6 +288,9 @@ class Training:
263288
many temporary files.
264289
"""
265290

291+
dataloader: DataLoader = field(default_factory=DataLoader)
292+
"""DataLoader configuration"""
293+
266294

267295
@dataclass
268296
class Parallelism:
@@ -908,6 +936,9 @@ class Validation:
908936
WARNING: When setting to -1 there could be hangs due to mismatch among ranks
909937
"""
910938

939+
dataloader: DataLoader = field(default_factory=DataLoader)
940+
"""DataLoader configuration"""
941+
911942
def __post_init__(self):
912943
assert (
913944
self.steps > 0 or self.steps == -1

torchtitan/experiments/vlm/datasets/mm_datasets.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,14 @@ def build_mm_dataloader(
381381
"""Build a data loader for multimodal datasets.
382382
383383
Args:
384-
dp_world_size: Data parallel world size
385-
dp_rank: Data parallel rank
386-
tokenizer: Tokenizer for text processing
387-
job_config: Job configuration
388-
infinite: Whether to loop infinitely
384+
dp_world_size: Data parallel world size.
385+
dp_rank: Data parallel rank.
386+
tokenizer: Tokenizer for text processing.
387+
job_config: Job configuration containing dataset and DataLoader settings.
388+
infinite: Whether to loop infinitely.
389389
390390
Returns:
391-
DataLoader with appropriate parallelism handling
391+
DataLoader with appropriate parallelism handling.
392392
"""
393393
dataset_path = job_config.training.dataset_path
394394
batch_size = job_config.training.local_batch_size
@@ -435,6 +435,10 @@ def build_mm_dataloader(
435435
dp_world_size=dp_world_size,
436436
batch_size=batch_size,
437437
collate_fn=collate_fn,
438+
num_workers=job_config.training.dataloader.num_workers,
439+
persistent_workers=job_config.training.dataloader.persistent_workers,
440+
prefetch_factor=job_config.training.dataloader.prefetch_factor,
441+
pin_memory=job_config.training.dataloader.pin_memory,
438442
)
439443

440444
return base_dataloader

torchtitan/hf_datasets/text_datasets.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,15 @@ def build_text_dataloader(
172172
job_config: JobConfig,
173173
infinite: bool = True,
174174
) -> ParallelAwareDataloader:
175-
"""Build a data loader for HuggingFace datasets."""
175+
"""Build a data loader for HuggingFace datasets.
176+
177+
Args:
178+
dp_world_size: Data parallelism world size.
179+
dp_rank: Data parallelism rank.
180+
tokenizer: Tokenizer to use for encoding text.
181+
job_config: Job configuration containing dataset and DataLoader settings.
182+
infinite: Whether to loop the dataset infinitely.
183+
"""
176184
dataset_name = job_config.training.dataset
177185
dataset_path = job_config.training.dataset_path
178186
batch_size = job_config.training.local_batch_size
@@ -193,6 +201,10 @@ def build_text_dataloader(
193201
dp_rank=dp_rank,
194202
dp_world_size=dp_world_size,
195203
batch_size=batch_size,
204+
num_workers=job_config.training.dataloader.num_workers,
205+
persistent_workers=job_config.training.dataloader.persistent_workers,
206+
prefetch_factor=job_config.training.dataloader.prefetch_factor,
207+
pin_memory=job_config.training.dataloader.pin_memory,
196208
)
197209

198210

@@ -203,7 +215,15 @@ def build_text_validation_dataloader(
203215
job_config: JobConfig,
204216
infinite: bool = False,
205217
) -> ParallelAwareDataloader:
206-
"""Build a validation data loader for HuggingFace datasets."""
218+
"""Build a validation data loader for HuggingFace datasets.
219+
220+
Args:
221+
dp_world_size: Data parallelism world size.
222+
dp_rank: Data parallelism rank.
223+
tokenizer: Tokenizer to use for encoding text.
224+
job_config: Job configuration containing dataset and DataLoader settings.
225+
infinite: Whether to loop the dataset infinitely.
226+
"""
207227
dataset_name = job_config.validation.dataset
208228
dataset_path = job_config.validation.dataset_path
209229
batch_size = job_config.validation.local_batch_size
@@ -224,4 +244,8 @@ def build_text_validation_dataloader(
224244
dp_rank=dp_rank,
225245
dp_world_size=dp_world_size,
226246
batch_size=batch_size,
247+
num_workers=job_config.validation.dataloader.num_workers,
248+
persistent_workers=job_config.validation.dataloader.persistent_workers,
249+
prefetch_factor=job_config.validation.dataloader.prefetch_factor,
250+
pin_memory=job_config.validation.dataloader.pin_memory,
227251
)

torchtitan/models/flux/flux_datasets.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,15 @@ def build_flux_dataloader(
314314
tokenizer: FluxTokenizer | None,
315315
infinite: bool = True,
316316
) -> ParallelAwareDataloader:
317-
"""Build a data loader for HuggingFace datasets."""
317+
"""Build a data loader for HuggingFace datasets.
318+
319+
Args:
320+
dp_world_size: Data parallelism world size.
321+
dp_rank: Data parallelism rank.
322+
job_config: Job configuration containing dataset and DataLoader settings.
323+
tokenizer: Tokenizer (kept for compatibility, not used).
324+
infinite: Whether to loop the dataset infinitely.
325+
"""
318326
dataset_name = job_config.training.dataset
319327
dataset_path = job_config.training.dataset_path
320328
batch_size = job_config.training.local_batch_size
@@ -337,6 +345,10 @@ def build_flux_dataloader(
337345
dp_rank=dp_rank,
338346
dp_world_size=dp_world_size,
339347
batch_size=batch_size,
348+
num_workers=job_config.training.dataloader.num_workers,
349+
persistent_workers=job_config.training.dataloader.persistent_workers,
350+
prefetch_factor=job_config.training.dataloader.prefetch_factor,
351+
pin_memory=job_config.training.dataloader.pin_memory,
340352
)
341353

342354

@@ -400,7 +412,16 @@ def build_flux_validation_dataloader(
400412
generate_timestamps: bool = True,
401413
infinite: bool = False,
402414
) -> ParallelAwareDataloader:
403-
"""Build a data loader for HuggingFace datasets."""
415+
"""Build a validation data loader for HuggingFace datasets.
416+
417+
Args:
418+
dp_world_size: Data parallelism world size.
419+
dp_rank: Data parallelism rank.
420+
job_config: Job configuration containing dataset and DataLoader settings.
421+
tokenizer: Tokenizer (kept for compatibility, not used).
422+
generate_timestamps: Whether to generate timesteps for validation.
423+
infinite: Whether to loop the dataset infinitely.
424+
"""
404425
dataset_name = job_config.validation.dataset
405426
dataset_path = job_config.validation.dataset_path
406427
batch_size = job_config.validation.local_batch_size
@@ -424,4 +445,8 @@ def build_flux_validation_dataloader(
424445
dp_rank=dp_rank,
425446
dp_world_size=dp_world_size,
426447
batch_size=batch_size,
448+
num_workers=job_config.validation.dataloader.num_workers,
449+
persistent_workers=job_config.validation.dataloader.persistent_workers,
450+
prefetch_factor=job_config.validation.dataloader.prefetch_factor,
451+
pin_memory=job_config.validation.dataloader.pin_memory,
427452
)

0 commit comments

Comments
 (0)