Commit b5f9517
refactor 1
1 parent 309bc7c

File tree: 11 files changed, +235 -11 lines

.ci/docker/common/install_conda.sh

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ install_pip_dependencies() {
   # Install all Python dependencies
   pip_install -r /opt/conda/requirements-dev.txt
   pip_install -r /opt/conda/requirements.txt
+  pip_install -r /opt/conda/requirements-flux.txt
   pip_install -r /opt/conda/requirements-vlm.txt
   popd
 }

.ci/docker/requirements-flux.txt

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+transformers
+einops
+sentencepiece
+pillow

.ci/docker/requirements.txt

Lines changed: 0 additions & 4 deletions

@@ -8,7 +8,3 @@ fsspec
 tyro
 tokenizers >= 0.15.0
 safetensors
-transformers
-einops
-sentencepiece
-pillow

Lines changed: 227 additions & 0 deletions

@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
from typing import Any, Callable

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import JobConfig
from torchtitan.hf_datasets import DatasetConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str, split: str):
    """Load C4 dataset with default configuration."""
    return load_dataset(dataset_path, name="en", split=split, streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Process C4 dataset sample text."""
    return sample["text"]


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=partial(_load_c4_dataset, split="train"),
        sample_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        sample_processor=_process_c4_text,
    ),
    "c4_validation": DatasetConfig(
        path="allenai/c4",
        loader=partial(_load_c4_dataset, split="validation"),
        sample_processor=_process_c4_text,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.sample_processor


class HuggingFaceTextDataset(IterableDataset, Stateful):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: BaseTokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._token_buffer: list[int] = []

    def _get_data_iter(self):
        # For map-style datasets, resume by skipping to the correct index.
        # For iterable-style datasets, the underlying iterator already points to the correct index.
        if isinstance(self._data, Dataset):
            if self._sample_idx == len(self._data):
                return iter([])
            else:
                return iter(self._data.skip(self._sample_idx))

        return iter(self._data)

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(
                    sample_text, add_bos=True, add_eos=True
                )
                self._token_buffer.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._token_buffer) >= max_buffer_token_len:
                    x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._token_buffer = self._token_buffer[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
                # Ensures re-looping a dataset loaded from a checkpoint works correctly
                if not isinstance(self._data, Dataset):
                    if hasattr(self._data, "set_epoch") and hasattr(
                        self._data, "epoch"
                    ):
                        self._data.set_epoch(self._data.epoch + 1)

    def load_state_dict(self, state_dict):
        self._token_buffer = state_dict["token_buffer"]

        if isinstance(self._data, Dataset):
            self._sample_idx = state_dict["sample_idx"]
        else:
            assert "data" in state_dict
            self._data.load_state_dict(state_dict["data"])

    def state_dict(self):
        _state_dict = {"token_buffer": self._token_buffer}

        if isinstance(self._data, Dataset):
            _state_dict["sample_idx"] = self._sample_idx
        else:
            # Save the iterable dataset's state to later efficiently resume from it
            # https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration
            _state_dict["data"] = self._data.state_dict()

        return _state_dict


def build_text_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: BaseTokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.local_batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceTextDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )


def build_text_validation_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: BaseTokenizer,
    job_config: JobConfig,
    infinite: bool = False,
) -> ParallelAwareDataloader:
    """Build a validation data loader for HuggingFace datasets."""
    dataset_name = job_config.validation.dataset
    dataset_path = job_config.validation.dataset_path
    batch_size = job_config.validation.local_batch_size
    seq_len = job_config.validation.seq_len

    hf_ds = HuggingFaceTextDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
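
Note on the new dataset module above: __iter__ packs tokenized samples into a rolling token buffer and emits fixed windows of seq_len + 1 tokens, split into a shifted (input, label) pair for next-token prediction; the leftover buffer and sample index are captured by state_dict() so a resumed run neither drops nor duplicates tokens. Below is a minimal standalone sketch of that packing pattern only — pack_tokens and the character-level "tokenizer" are hypothetical stand-ins for illustration, not part of torchtitan:

import torch

def pack_tokens(token_stream, seq_len):
    # Greedily pack a stream of token lists into (input, label) windows,
    # mirroring the buffer logic in HuggingFaceTextDataset.__iter__.
    buffer: list[int] = []
    max_buffer_token_len = 1 + seq_len
    for sample_tokens in token_stream:
        buffer.extend(sample_tokens)
        while len(buffer) >= max_buffer_token_len:
            window = torch.LongTensor(buffer[:max_buffer_token_len])
            buffer = buffer[max_buffer_token_len:]
            # Labels are the inputs shifted by one position.
            yield window[:-1], window[1:]

# Toy character-level "tokenizer" standing in for BaseTokenizer.encode.
samples = ["hello world", "flux refactor", "torchtitan"]
stream = ([ord(c) for c in s] for s in samples)
for inp, label in pack_tokens(stream, seq_len=8):
    print(inp.tolist(), label.tolist())
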
File renamed without changes.
File renamed without changes.
File renamed without changes.

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+.ci/docker/requirements-flux.txt

scripts/flux_inference/run_infer.sh renamed to torchtitan/models/flux/run_infer.sh

Lines changed: 1 addition & 1 deletion

@@ -17,6 +17,6 @@ CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.t
 PYTORCH_ALLOC_CONF="expandable_segments:True" \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
   --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-  -m scripts.flux_inference.infer --job.config_file ${CONFIG_FILE} \
+  -m torchtitan.models.flux.inference.infer --job.config_file ${CONFIG_FILE} \
   --checkpoint.enable \
   --checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer "$@"

torchtitan/models/flux/tests/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.
