
Conversation

@divyanshk

This diff introduces common dataloader args that are supported by StatefulDataLoader (and torch.utils.data DataLoader). Users should be able to use them in their config files.

I considered introducing a catch-all kwargs to make it easier to specify args, but that can easily complicate things (validation checks, duplication, clashes with named args already defined in function signatures, etc.).

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Dec 2, 2025
@divyanshk force-pushed the divyanshk/dataloader_args branch from 6763cc0 to 990d654 on December 2, 2025 01:04
@divyanshk marked this pull request as ready for review December 3, 2025 17:09
@wwwjn (Contributor) left a comment


Thank you! I'm slightly leaning towards using kwargs instead of adding these parameters one by one, because StatefulDataLoader() supports a lot of fields and it's hard to say which of them are "common" across different use cases.

Can you explain more on "but that can easily complicate things"? We could just pass all the kwargs to StatefulDataLoader and let it check correctness. wdyt @tianyu-l
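The "let the loader check correctness" idea can be sketched with a stand-in class; `StatefulDataLoaderStub` and `build_dataloader` below are hypothetical names, not the real API:

```python
# Stand-in for StatefulDataLoader, for illustration only.
class StatefulDataLoaderStub:
    def __init__(self, dataset, num_workers=0, pin_memory=False):
        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory


def build_dataloader(dataset, **kwargs):
    # Forward everything; an unsupported key raises TypeError in the
    # constructor itself, so no separate validation layer is needed.
    return StatefulDataLoaderStub(dataset, **kwargs)


dl = build_dataloader([1, 2, 3], num_workers=2)
```

The upside is zero duplication; the downside (raised earlier in the thread) is that errors surface inside the inner library rather than at torchtitan's boundary.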

Comment on lines 89 to 92:

```python
num_workers=num_workers,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
pin_memory=pin_memory,
```
Contributor

> I was thinking about introducing a catch all kwargs to make it easier to specify args but that can easily complicate things (validation checks, duplication, existing defined named args in function definitions etc).

These are valid concerns. For now I'm leaning towards keeping things simple by passing **kwargs around.

Does it make sense if we only make these args explicit when calling the actual __init__ of StatefulDataLoader, and don't pass all **kwargs from the input of ParallelAwareDataloader? The point is to not accidentally hit an error inside StatefulDataLoader.
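A minimal sketch of this suggestion, with stand-in class names (not the actual torchtitan code): the wrapper exposes named parameters at its own boundary and forwards them explicitly, so an unsupported option fails before ever reaching the inner loader.

```python
# All names below are illustrative stand-ins.
class InnerLoader:  # stand-in for StatefulDataLoader
    def __init__(self, dataset, num_workers=0, persistent_workers=False,
                 prefetch_factor=None, pin_memory=False):
        self.num_workers = num_workers
        self.pin_memory = pin_memory


class ParallelAwareDataloaderSketch:
    def __init__(self, dataset, num_workers=0, persistent_workers=False,
                 prefetch_factor=None, pin_memory=False):
        # Only these named fields are forwarded; a bad option is rejected
        # here, at the wrapper's signature, not deep inside InnerLoader.
        self.inner = InnerLoader(
            dataset,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
            pin_memory=pin_memory,
        )


dl = ParallelAwareDataloaderSketch([1, 2, 3], num_workers=2, pin_memory=True)
```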

```python
self,
dataset: IterableDataset,
dp_rank: int,
dp_world_size: int,
```
Contributor

Could you help change this: let's keep at most one positional arg (dataset) and make the others keyword-only.
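The requested signature can be sketched with Python's bare `*` marker, which makes everything after `dataset` keyword-only (the class name below is illustrative):

```python
# Illustrative sketch of a keyword-only signature.
class LoaderSketch:
    def __init__(self, dataset, *, dp_rank: int, dp_world_size: int):
        # `dataset` may be passed positionally; dp_rank and dp_world_size
        # must be named at the call site, so calls stay self-documenting.
        self.dataset = dataset
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size


ok = LoaderSketch([0, 1], dp_rank=0, dp_world_size=2)
```

Calling `LoaderSketch([0, 1], 0, 2)` would raise a TypeError, since the extra positional arguments are no longer accepted.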

@divyanshk
Author

Thanks @tianyu-l @wwwjn! Updated the PR with the kwargs-based approach. I initially didn't do this to avoid any confusion on the user's part, because we provide batch_size and collate_fn (in mm_datasets) internally. I resolved that by making explicitly defined internal args take precedence, and added a warning for users in config.py, so that should help. The error from wrong kwargs (if any) will be thrown in torchtitan itself; it won't go down to StatefulDataLoader.
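The precedence rule described here (internally defined args win over user kwargs, with a warning) might look roughly like the following; the helper name and exact message are hypothetical:

```python
# Hypothetical helper illustrating the precedence rule, not torchtitan's code.
import warnings


def merge_dataloader_kwargs(user_kwargs, internal_kwargs):
    merged = dict(user_kwargs)
    for key, value in internal_kwargs.items():
        if key in merged and merged[key] != value:
            # Make the override visible instead of silently dropping input.
            warnings.warn(
                f"'{key}' is set internally by torchtitan; "
                f"ignoring user-provided value {merged[key]!r}"
            )
        merged[key] = value
    return merged


merged = merge_dataloader_kwargs(
    {"num_workers": 2, "batch_size": 64},   # from the user's config
    {"batch_size": 8, "collate_fn": None},  # set internally
)
```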

@tianyu-l (Contributor) left a comment


Looks good in general.

The CPU unit test in CI didn't run. Could you double-check?

Also, please add a GPU integration test; see inline comments.

```
- batch_size: Determined by training.local_batch_size
- collate_fn: Set by the dataset-specific collator
Example (TOML config file):
```
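The TOML example itself is truncated in this view. As a hedged sketch only, based on the kwargs used elsewhere in this thread, such a section might look like the following (the exact table and key names are assumptions):

```toml
# Hypothetical sketch; actual table/key names in torchtitan may differ.
[training.dataloader.kwargs]
num_workers = 2
pin_memory = true
prefetch_factor = 2
```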
Contributor

Could you add a dedicated test for the dataloader with kwargs passed through?
https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/features.py

Author

Added a GPU integration test. To be able to use the CLI to pass in the kwargs, I added a tyro rule. I am not super familiar with tyro, so please have a look.

Also, shout-out to the integration test setup. Love that we can do a quick mini-GPU run as part of feature testing.

```python
OverrideDefinitions(
    [
        [
            '--training.dataloader.kwargs \'{"num_workers": 2, "pin_memory": true, "prefetch_factor": 2}\'',
```
Contributor

Instead of letting the CLI accept a dict, can we just do

```
--training.dataloader.kwargs.num_workers 2 --training.dataloader.kwargs.pin_memory true ...
```
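The dotted form implies that each `--a.b.c value` pair builds one nested key. A toy parser sketching that mapping (tyro's actual handling is far more sophisticated; this is illustration only):

```python
# Toy illustration of dotted CLI flags -> nested dict; not tyro's behavior.
def parse_dotted_flags(argv):
    out = {}
    for flag, raw in zip(argv[::2], argv[1::2]):
        *parents, leaf = flag.lstrip("-").split(".")
        node = out
        for part in parents:
            node = node.setdefault(part, {})
        # Crude literal conversion, enough for the illustration.
        if raw in ("true", "false"):
            value = raw == "true"
        else:
            try:
                value = int(raw)
            except ValueError:
                value = raw
        node[leaf] = value
    return out


flags = parse_dotted_flags(
    ["--training.dataloader.kwargs.num_workers", "2",
     "--training.dataloader.kwargs.pin_memory", "true"]
)
```

This shape keeps each option shell-friendly and individually overridable, at the cost of typing the full dotted path per option.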
