@@ -314,7 +314,15 @@ def build_flux_dataloader(
314314 tokenizer : FluxTokenizer | None ,
315315 infinite : bool = True ,
316316) -> ParallelAwareDataloader :
317- """Build a data loader for HuggingFace datasets."""
317+ """Build a data loader for HuggingFace datasets.
318+
319+ Args:
320+ dp_world_size: Data parallelism world size.
321+ dp_rank: Data parallelism rank.
322+ job_config: Job configuration containing dataset and DataLoader settings.
323+ tokenizer: Tokenizer (kept for compatibility, not used).
324+ infinite: Whether to loop the dataset infinitely.
325+ """
318326 dataset_name = job_config .training .dataset
319327 dataset_path = job_config .training .dataset_path
320328 batch_size = job_config .training .local_batch_size
@@ -337,6 +345,10 @@ def build_flux_dataloader(
337345 dp_rank = dp_rank ,
338346 dp_world_size = dp_world_size ,
339347 batch_size = batch_size ,
348+ num_workers = job_config .training .dataloader .num_workers ,
349+ persistent_workers = job_config .training .dataloader .persistent_workers ,
350+ prefetch_factor = job_config .training .dataloader .prefetch_factor ,
351+ pin_memory = job_config .training .dataloader .pin_memory ,
340352 )
341353
342354
@@ -400,7 +412,16 @@ def build_flux_validation_dataloader(
400412 generate_timestamps : bool = True ,
401413 infinite : bool = False ,
402414) -> ParallelAwareDataloader :
403- """Build a data loader for HuggingFace datasets."""
415+ """Build a validation data loader for HuggingFace datasets.
416+
417+ Args:
418+ dp_world_size: Data parallelism world size.
419+ dp_rank: Data parallelism rank.
420+ job_config: Job configuration containing dataset and DataLoader settings.
421+ tokenizer: Tokenizer (kept for compatibility, not used).
422+ generate_timestamps: Whether to generate timesteps for validation.
423+ infinite: Whether to loop the dataset infinitely.
424+ """
404425 dataset_name = job_config .validation .dataset
405426 dataset_path = job_config .validation .dataset_path
406427 batch_size = job_config .validation .local_batch_size
@@ -424,4 +445,8 @@ def build_flux_validation_dataloader(
424445 dp_rank = dp_rank ,
425446 dp_world_size = dp_world_size ,
426447 batch_size = batch_size ,
448+ num_workers = job_config .validation .dataloader .num_workers ,
449+ persistent_workers = job_config .validation .dataloader .persistent_workers ,
450+ prefetch_factor = job_config .validation .dataloader .prefetch_factor ,
451+ pin_memory = job_config .validation .dataloader .pin_memory ,
427452 )
0 commit comments