Skip to content

Commit 686b523

Browse files
authored
fix: datasets broken import due to HF package and folder name collision (#1730)
This PR resolves the `datasets` import error that were introduced in #1712. The error causes the following error message: ```Traceback (most recent call last): File "./torchtitan/train.py", line 16, in <module> import torchtitan.protocols.train_spec as train_spec_module File "./torchtitan/__init__.py", line 12, in <module> import torchtitan.experiments # noqa: F401 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "./torchtitan/experiments/__init__.py", line 7, in <module> import torchtitan.experiments.llama4 # noqa: F401 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "./torchtitan/experiments/llama4/__init__.py", line 11, in <module> from torchtitan.datasets.hf_datasets import build_hf_dataloader File "./torchtitan/datasets/hf_datasets.py", line 12, in <module> from datasets import Dataset, load_dataset ImportError: cannot import name 'Dataset' from 'datasets' (./torchtitan/datasets/__init__.py) ``` Why is this happening? It is because #1712 added an `__init__.py` file to the `datasets` folder. On the surface the PR looks fine, however, it causes a collision with HF datasets Python package called `datasets`. So when we try to import the `Dataset` class, we actually want the HF datasets package and not the local `datasets` folder. This, in turn, causes an import error. The solution is simple, change the `__init__.py` name to something else, in our case I changed it to `common.py` as I found this name most fitting.
1 parent 476a965 commit 686b523

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

multinode_trainer.slurm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,5 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.t
5959
dcgmi profile --pause
6060
# adjust sbatch --ntasks and sbatch --nodes above and --nnodes below
6161
# to your specific node count, and update target launch file.
62-
srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE} "$@"
62+
srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" -m torchtitan.train --job.config_file ${CONFIG_FILE} "$@"
6363
dcgmi profile --resume

0 commit comments

Comments
 (0)