From a8b5ec6c53f6eb026c917ac72f9e58f0757b73dd Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 18 Oct 2024 13:50:17 -0700 Subject: [PATCH 01/57] WIP: Generate a mixture dataset --- src/olmo_core/data/mixture_dataset.py | 230 ++++++++++++++++ src/test/data/mixture_dataset_test.py | 376 ++++++++++++++++++++++++++ 2 files changed, 606 insertions(+) create mode 100644 src/olmo_core/data/mixture_dataset.py create mode 100644 src/test/data/mixture_dataset_test.py diff --git a/src/olmo_core/data/mixture_dataset.py b/src/olmo_core/data/mixture_dataset.py new file mode 100644 index 00000000..73481721 --- /dev/null +++ b/src/olmo_core/data/mixture_dataset.py @@ -0,0 +1,230 @@ +import math +import os +import random +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List + +import smart_open +import tabulate + +from olmo_core.aliases import PathOrStr +from olmo_core.config import Config +from olmo_core.data import NumpyDatasetDType +from olmo_core.data.utils import load_array_slice, memmap_to_write +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.io import get_bytes_range, get_file_size + + +@dataclass +class SourceMixtureConfig(Config): + source_name: str + target_ratio: float + paths: List[PathOrStr] + # 1.0 will result in a maximum of 1 repitition of the source data per epoch + max_repetion_ratio: float = 1.0 + max_source_fraction: float = 1.0 + + def validate(self): + if self.target_ratio: + if not 0 <= self.target_ratio <= 1: + raise OLMoConfigurationError("target_ratio must be in the range [0, 1]") + if not 0 <= self.max_source_fraction <= 1: + raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") + if self.max_source_fraction < self.target_ratio: + raise OLMoConfigurationError("max_source_fraction must be >= target_ratio") + + if not self.paths: + raise OLMoConfigurationError("paths must not be empty") + + if not 0 <= self.max_source_fraction <= 1: + raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") + + +@dataclass +class SourceTokenDetails: + """ + A class to hold intermediate selection details for a mixture source. + """ + + source: SourceMixtureConfig + source_population: int + num_selected: int + + def for_table(self, max_tokens: int) -> Dict: + return { + "source_name": self.source.source_name, + "source_population": f"{self.source_population:.2e}", + "num_sampled": f"{self.num_selected:.2e}", + "target_ratio": self.source.target_ratio, + "max_repetion_ratio": self.source.max_repetion_ratio, + "max_source_fraction": self.source.max_source_fraction, + "observed_source_ratio": self.num_selected / self.source_population, + "observed_global_ratio": self.num_selected / max_tokens, + } + + +@dataclass +class SourceMixture: + """ + A fractionalized mixture of source tokens. + """ + + source_name: str + paths: List[str] + + +@dataclass +class SourceMixtureDataset: + """ + A dataset consisting of a fractionalized mixture of data sources. + """ + + sources: List[SourceMixture] + + +@dataclass +class SourceMixtureDatasetConfig(Config): + """ + A configuration class for building a dataset from a fractionalized mixture of sources. + """ + + max_tokens: int + source_configs: List[SourceMixtureConfig] + dtype: NumpyDatasetDType + output_dir: PathOrStr + processes: int = 1 + seed: int = 42 + + def __post_init__(self): + os.makedirs(self.output_dir, exist_ok=True) + + def validate(self): + if (total := sum([source.target_ratio for source in self.source_configs])) != 1.0: + raise OLMoConfigurationError(f"target_ratios must sum to 1, got {total}") + + def build(self) -> SourceMixtureDataset: + self.validate() + random.seed(self.seed) + available_tokens_by_source: Dict[str, int] = {} + + # Count the number of tokens available for each source + for source_config in self.source_configs: + available_tokens_by_source[source_config.source_name] = self._count_tokens_for_source( + source_config, self.dtype + ) + + tokens_outcome_per_source: List[SourceTokenDetails] = [] + + # Calculate the number of tokens to include for each source + for source_config in self.source_configs: + available_for_source = available_tokens_by_source[source_config.source_name] + target_for_source = int(self.max_tokens * source_config.target_ratio) + max_for_source = int( + available_for_source + * source_config.max_source_fraction + * source_config.max_repetion_ratio + ) + + # Ensure that the available source tokens meet the target ratio requirement + if not max_for_source >= target_for_source: + raise OLMoConfigurationError( + f"Insufficient tokens for source: {source_config.source_name} @ target global ratio: {source_config.target_ratio} :: {max_for_source} < {target_for_source}" + ) + + tokens_outcome_per_source.append( + SourceTokenDetails( + source=source_config, + source_population=available_for_source, + num_selected=target_for_source, + ) + ) + + completed = [] + for outcome in tokens_outcome_per_source: + completed.append(self._handle_source_outcome(outcome)) + + print("Mixing outcome by source:") + print( + tabulate.tabulate( + [item.for_table(self.max_tokens) for item in tokens_outcome_per_source], + headers="keys", + tablefmt="pretty", + ), + ) + + return SourceMixtureDataset(completed) + + def _handle_source_outcome(self, outcome: SourceTokenDetails) -> SourceMixture: + """ + Write selected tokens for a source to a local file and return the path. + """ + return SourceMixture( + source_name=outcome.source.source_name, + paths=self._write_tokens_for_source(self.dtype, outcome.num_selected, outcome.source), + ) + + def _write_tokens_for_source( + self, dtype: NumpyDatasetDType, tokens_to_take: int, source_config: SourceMixtureConfig + ) -> List[str]: + """ + Stream selected tokens into a local file based on selection criteria. + """ + # Shuffle the paths to avoid biasing our selection to sequential file paths + paths = source_config.paths.copy() + random.shuffle(paths) + tokens_taken = 0 + written: List[str] = [] + + # Make sure we have enough paths to accommodate repetitions + for idx, path in enumerate(paths * math.ceil(source_config.max_repetion_ratio)): + # Stop if we've taken enough tokens + if tokens_taken >= tokens_to_take: + break + + filename = f"{self.output_dir}/{idx:05}_{source_config.source_name}.npy" + nda = load_array_slice(path, 0, tokens_to_take - tokens_taken, dtype.as_np_dtype()) + with memmap_to_write( + path=Path(filename), shape=(len(nda),), dtype=dtype.as_np_dtype() + ) as mm: + mm[:] = nda + + written.append(filename) + tokens_taken += tokens_to_take - tokens_taken + + return written + + def _count_tokens_for_source( + self, source_config: SourceMixtureConfig, dtype: NumpyDatasetDType + ) -> int: + """ + Count the number of tokens for a set of source token files in parallel. + + Args: + source_config: The source configuration. + dtype: The data type of the source tokens. + """ + + def _count_tokens(path) -> int: + size = get_file_size(path) + tokens = self._bytes_to_tokens(size, dtype) + return tokens + + with ThreadPoolExecutor(max_workers=self.processes) as executor: + return sum(executor.map(_count_tokens, source_config.paths)) + + def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: + """ + Convert bytes to tokens based on the dtype. + """ + npdtype = dtype.as_np_dtype() + return num_bytes // npdtype(int(0)).itemsize + + def _tokens_to_bytes(self, num_tokens: int, dtype: NumpyDatasetDType) -> int: + """ + Convert tokens to bytes based on the dtype. + """ + + npdtype = dtype.as_np_dtype() + return num_tokens * npdtype(int(0)).itemsize diff --git a/src/test/data/mixture_dataset_test.py b/src/test/data/mixture_dataset_test.py new file mode 100644 index 00000000..ac371dfb --- /dev/null +++ b/src/test/data/mixture_dataset_test.py @@ -0,0 +1,376 @@ +from itertools import chain +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import List + +import numpy as np +import pytest + +from olmo_core.aliases import PathOrStr +from olmo_core.data import NumpyDatasetDType +from olmo_core.data.mixture_dataset import ( + SourceMixtureConfig, + SourceMixtureDataset, + SourceMixtureDatasetConfig, +) +from olmo_core.data.utils import load_array_slice +from olmo_core.exceptions import OLMoConfigurationError + +DATA = { + "dtype": NumpyDatasetDType.uint32, + "itemsize": np.dtype(np.uint32).itemsize, + "tokens_per_file": 1_000_000, +} + + +def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> List[PathOrStr]: + mmaps = [] + for i in range(num_files): + filepath = f"{tmp_path}/{prefix}_{i}.npy" + data = np.random.randint(0, 2**32, size=size, dtype=np.uint32) + mm = np.memmap(filepath, mode="w+", dtype=DATA["dtype"].as_np_dtype(), shape=(size,)) + mm[:] = data + mm.flush() + mmaps.append(Path(filepath)) + + return mmaps + + +def test_source_mixture_config_validation(): + with pytest.raises(OLMoConfigurationError): + SourceMixtureConfig( + source_name="source1", target_ratio=1.2, paths=["/path/to/source1"] + ).validate() + + with pytest.raises(OLMoConfigurationError): + SourceMixtureConfig( + source_name="source1", + target_ratio=0.5, + max_source_fraction=0.4, + paths=["/path/to/source1"], + ).validate() + + with pytest.raises(OLMoConfigurationError): + SourceMixtureConfig(source_name="source1", target_ratio=0.5, paths=[]).validate() + + config = SourceMixtureConfig( + source_name="source1", target_ratio=0.5, paths=["/path/to/source1"] + ) + config.validate() + + +def test_dataset_mixture_config_validation(): + source_configs = [ + SourceMixtureConfig(source_name="source1", target_ratio=0.5, paths=["/path/to/source1"]), + SourceMixtureConfig(source_name="source2", target_ratio=0.5, paths=["/path/to/source2"]), + ] + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=1000, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + config.validate() + + source_configs_invalid = [ + SourceMixtureConfig( + source_name="source1", target_ratio=0.7, paths=["/path/to/source1"] + ), + SourceMixtureConfig( + source_name="source2", target_ratio=0.5, paths=["/path/to/source2"] + ), + ] + + config_invalid = SourceMixtureDatasetConfig( + max_tokens=1000, + source_configs=source_configs_invalid, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + with pytest.raises(OLMoConfigurationError): + config_invalid.validate() + + +def test_dataset_mixture_build(tmp_path: Path): + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=2, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.33, + paths=source_paths["1"], + ), + SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=source_paths["2"]), + SourceMixtureConfig( + source_name="3", + target_ratio=0.34, + paths=source_paths["3"], + ), + ] + + max_tokens = 5_000_000 + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) + + +def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.5, + paths=source_paths["1"], + ), + SourceMixtureConfig(source_name="2", target_ratio=0.25, paths=source_paths["2"]), + SourceMixtureConfig( + source_name="3", + target_ratio=0.25, + paths=source_paths["3"], + ), + ] + + max_tokens = 5_000_000 + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + # Should raise exception because the target ratio for source 1 @50% (2.5M) is infeasible without repetition + with pytest.raises(OLMoConfigurationError): + config.build() + + +def test_dataset_mixture_build_with_repetition(tmp_path: Path): + """ + Test building a dataset with repetition of a source. + + Source 1 has a target ratio of 90% and a max repetition ratio of 4.0, so it should be possible to meet the target of 3600 tokens with 1 file of 1000 tokens repeated 4 times. + """ + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.5, + max_repetion_ratio=3.0, # Allow 3x repetition of source1 so that we can meet the target of 2.5M + paths=source_paths["1"], + ), + SourceMixtureConfig(source_name="2", target_ratio=0.25, paths=source_paths["2"]), + SourceMixtureConfig( + source_name="3", + target_ratio=0.25, + paths=source_paths["3"], + ), + ] + + max_tokens = 5_000_000 + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) + + +def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.25, + paths=source_paths["1"], + max_source_fraction=0.10, # Allow only 10% of source1 to be used (population is 1M tokens) + ), + SourceMixtureConfig( + source_name="2", + target_ratio=0.25, + paths=source_paths["2"], + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.5, + paths=source_paths["3"], + ), + ] + + # 5 source files * 1_000_000 tokens per file + max_tokens = len(list(chain(*source_paths.values()))) * DATA["tokens_per_file"] + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + # Should raise exception because the target ratio for source 1 is infeasible because + # we limit usage to 10% of the source + with pytest.raises(OLMoConfigurationError): + config.build() + + +def test_dataset_mixture_build_expected_files(tmp_path: Path): + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.10, + paths=source_paths["1"], + ), + SourceMixtureConfig( + source_name="2", + target_ratio=0.40, + paths=source_paths["2"], + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.5, + paths=source_paths["3"], + ), + ] + + max_tokens = 10 * 1000 + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) + + out_tokens = [] + + for source in mixture.sources: + for path in source.paths: + out_tokens.extend( + load_array_slice( + path=path, + start_idx=0, + end_idx=DATA["tokens_per_file"], + dtype=DATA["dtype"].as_np_dtype(), + ) + ) + + assert len(out_tokens) == max_tokens + + +def test_dataset_mixture_render_table(tmp_path: Path, capsys): + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.10, + paths=source_paths["1"], + ), + SourceMixtureConfig( + source_name="2", + target_ratio=0.40, + paths=source_paths["2"], + ), + SourceMixtureConfig( + source_name="3", + target_ratio=0.5, + paths=source_paths["3"], + ), + ] + + max_tokens = 10 * 1000 + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + ) + + with capsys.disabled(): + print("\n") + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) From 637fee9afc39c78ba5cbf71d078c9102d2516071 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 18 Oct 2024 14:06:44 -0700 Subject: [PATCH 02/57] WIP: Adds dry run --- src/olmo_core/data/mixture_dataset.py | 14 +++++++-- src/test/data/mixture_dataset_test.py | 44 +++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/data/mixture_dataset.py b/src/olmo_core/data/mixture_dataset.py index 73481721..43666ec3 100644 --- a/src/olmo_core/data/mixture_dataset.py +++ b/src/olmo_core/data/mixture_dataset.py @@ -96,11 +96,18 @@ class SourceMixtureDatasetConfig(Config): output_dir: PathOrStr processes: int = 1 seed: int = 42 + dry_run: bool = False def __post_init__(self): os.makedirs(self.output_dir, exist_ok=True) def validate(self): + if self.max_tokens <= 0: + raise OLMoConfigurationError("max_tokens must be > 0") + + if not self.source_configs: + raise OLMoConfigurationError("source_configs must not be empty") + if (total := sum([source.target_ratio for source in self.source_configs])) != 1.0: raise OLMoConfigurationError(f"target_ratios must sum to 1, got {total}") @@ -142,10 +149,11 @@ def build(self) -> SourceMixtureDataset: ) completed = [] - for outcome in tokens_outcome_per_source: - completed.append(self._handle_source_outcome(outcome)) + if not self.dry_run: + for outcome in tokens_outcome_per_source: + completed.append(self._handle_source_outcome(outcome)) - print("Mixing outcome by source:") + print(f"Mixing outcome by source: {'' if not self.dry_run else '(DRY RUN)'}") print( tabulate.tabulate( [item.for_table(self.max_tokens) for item in tokens_outcome_per_source], diff --git a/src/test/data/mixture_dataset_test.py b/src/test/data/mixture_dataset_test.py index ac371dfb..611f3877 100644 --- a/src/test/data/mixture_dataset_test.py +++ b/src/test/data/mixture_dataset_test.py @@ -36,6 +36,50 @@ def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> List[ return mmaps +def test_source_mixture_config_dry_run(tmp_path: Path, capsys): + source_paths = { + "1": _make_mmaps( + tmp_path=tmp_path, prefix="source1", num_files=2, size=DATA["tokens_per_file"] + ), + "2": _make_mmaps( + tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + ), + "3": _make_mmaps( + tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + ), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.33, + paths=source_paths["1"], + ), + SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=source_paths["2"]), + SourceMixtureConfig( + source_name="3", + target_ratio=0.34, + paths=source_paths["3"], + ), + ] + + max_tokens = 5_000_000 + + with TemporaryDirectory() as tmp_dir: + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + output_dir=tmp_dir, + dry_run=True, + ) + + with capsys.disabled(): + print("\n") + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) + + def test_source_mixture_config_validation(): with pytest.raises(OLMoConfigurationError): SourceMixtureConfig( From 346135c83ad285ec9c6204e53e4b79fe3972b333 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 18 Oct 2024 14:21:06 -0700 Subject: [PATCH 03/57] Test cleanup --- src/olmo_core/data/mixture_dataset.py | 12 ++---------- src/test/data/mixture_dataset_test.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/olmo_core/data/mixture_dataset.py b/src/olmo_core/data/mixture_dataset.py index 43666ec3..155d5269 100644 --- a/src/olmo_core/data/mixture_dataset.py +++ b/src/olmo_core/data/mixture_dataset.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Dict, List -import smart_open import tabulate from olmo_core.aliases import PathOrStr @@ -14,7 +13,7 @@ from olmo_core.data import NumpyDatasetDType from olmo_core.data.utils import load_array_slice, memmap_to_write from olmo_core.exceptions import OLMoConfigurationError -from olmo_core.io import get_bytes_range, get_file_size +from olmo_core.io import get_file_size @dataclass @@ -186,6 +185,7 @@ def _write_tokens_for_source( written: List[str] = [] # Make sure we have enough paths to accommodate repetitions + # TODO: Need to make this go birrr for idx, path in enumerate(paths * math.ceil(source_config.max_repetion_ratio)): # Stop if we've taken enough tokens if tokens_taken >= tokens_to_take: @@ -228,11 +228,3 @@ def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: """ npdtype = dtype.as_np_dtype() return num_bytes // npdtype(int(0)).itemsize - - def _tokens_to_bytes(self, num_tokens: int, dtype: NumpyDatasetDType) -> int: - """ - Convert tokens to bytes based on the dtype. - """ - - npdtype = dtype.as_np_dtype() - return num_tokens * npdtype(int(0)).itemsize diff --git a/src/test/data/mixture_dataset_test.py b/src/test/data/mixture_dataset_test.py index 611f3877..26d8e365 100644 --- a/src/test/data/mixture_dataset_test.py +++ b/src/test/data/mixture_dataset_test.py @@ -18,7 +18,6 @@ DATA = { "dtype": NumpyDatasetDType.uint32, - "itemsize": np.dtype(np.uint32).itemsize, "tokens_per_file": 1_000_000, } @@ -171,7 +170,7 @@ def test_dataset_mixture_build(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, + dtype=DATA["dtype"], output_dir=tmp_dir, ) @@ -211,11 +210,11 @@ def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, + dtype=DATA["dtype"], output_dir=tmp_dir, ) - # Should raise exception because the target ratio for source 1 @50% (2.5M) is infeasible without repetition + # Should raise exception because the target ratio for source 1 @50% (2.5M) is infeasible without repetition (default max_repetition_ratio=1) with pytest.raises(OLMoConfigurationError): config.build() @@ -259,7 +258,7 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, + dtype=DATA["dtype"], output_dir=tmp_dir, ) @@ -305,7 +304,7 @@ def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, + dtype=DATA["dtype"], output_dir=tmp_dir, ) @@ -351,7 +350,7 @@ def test_dataset_mixture_build_expected_files(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, + dtype=DATA["dtype"], output_dir=tmp_dir, ) @@ -410,7 +409,7 @@ def test_dataset_mixture_render_table(tmp_path: Path, capsys): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, + dtype=DATA["dtype"], output_dir=tmp_dir, ) From 53def383cdfeefffcf5a1fef2f8c6e18bf2f9376 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 18 Oct 2024 18:13:43 -0700 Subject: [PATCH 04/57] WIP: Make it fast --- src/olmo_core/data/mixture_dataset.py | 127 +++++++++++++++++++------- src/test/data/mixture_dataset_test.py | 14 +-- 2 files changed, 103 insertions(+), 38 deletions(-) diff --git a/src/olmo_core/data/mixture_dataset.py b/src/olmo_core/data/mixture_dataset.py index 155d5269..d7df8deb 100644 --- a/src/olmo_core/data/mixture_dataset.py +++ b/src/olmo_core/data/mixture_dataset.py @@ -1,10 +1,13 @@ import math +import multiprocessing as mp import os import random -from concurrent.futures import ThreadPoolExecutor +import threading +from concurrent.futures import as_completed, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional +from tqdm import tqdm import tabulate @@ -16,13 +19,33 @@ from olmo_core.io import get_file_size +class ValueLock: + def __init__(self): + self._lock = threading.Lock() + self._value = 0 + + def increment(self) -> int: + with self._lock: + self._value += 1 + return self._value + + def add(self, value) -> int: + with self._lock: + self._value = self._value + value + return self._value + + def value(self) -> int: + with self._lock: + return self._value + + @dataclass class SourceMixtureConfig(Config): source_name: str target_ratio: float paths: List[PathOrStr] # 1.0 will result in a maximum of 1 repitition of the source data per epoch - max_repetion_ratio: float = 1.0 + max_repetition_ratio: float = 1.0 max_source_fraction: float = 1.0 def validate(self): @@ -57,7 +80,7 @@ def for_table(self, max_tokens: int) -> Dict: "source_population": f"{self.source_population:.2e}", "num_sampled": f"{self.num_selected:.2e}", "target_ratio": self.source.target_ratio, - "max_repetion_ratio": self.source.max_repetion_ratio, + "max_repetion_ratio": self.source.max_repetition_ratio, "max_source_fraction": self.source.max_source_fraction, "observed_source_ratio": self.num_selected / self.source_population, "observed_global_ratio": self.num_selected / max_tokens, @@ -117,8 +140,9 @@ def build(self) -> SourceMixtureDataset: # Count the number of tokens available for each source for source_config in self.source_configs: - available_tokens_by_source[source_config.source_name] = self._count_tokens_for_source( - source_config, self.dtype + print("Counting tokens for source: ", source_config.source_name) + available_tokens_by_source[source_config.source_name] = self._count_tokens_for_paths( + source_config.paths ) tokens_outcome_per_source: List[SourceTokenDetails] = [] @@ -130,7 +154,7 @@ def build(self) -> SourceMixtureDataset: max_for_source = int( available_for_source * source_config.max_source_fraction - * source_config.max_repetion_ratio + * source_config.max_repetition_ratio ) # Ensure that the available source tokens meet the target ratio requirement @@ -176,51 +200,92 @@ def _write_tokens_for_source( self, dtype: NumpyDatasetDType, tokens_to_take: int, source_config: SourceMixtureConfig ) -> List[str]: """ - Stream selected tokens into a local file based on selection criteria. + Write selected tokens into a local file based on selection criteria. """ # Shuffle the paths to avoid biasing our selection to sequential file paths paths = source_config.paths.copy() random.shuffle(paths) - tokens_taken = 0 written: List[str] = [] + taken = ValueLock() + m = mp.Manager() + write_lock = m.Lock() - # Make sure we have enough paths to accommodate repetitions - # TODO: Need to make this go birrr - for idx, path in enumerate(paths * math.ceil(source_config.max_repetion_ratio)): - # Stop if we've taken enough tokens - if tokens_taken >= tokens_to_take: - break + with ThreadPoolExecutor(max_workers=self.processes) as executor: + print(f"Collecting {tokens_to_take:.2e} tokens for {source_config.source_name}") + futures = [] + for idx, path in enumerate(paths * math.ceil(source_config.max_repetition_ratio)): + futures.append( + executor.submit( + self._load_and_write_tokens, + idx, + path, + dtype, + tokens_to_take, + source_config.source_name, + taken, + write_lock, + ) + ) - filename = f"{self.output_dir}/{idx:05}_{source_config.source_name}.npy" - nda = load_array_slice(path, 0, tokens_to_take - tokens_taken, dtype.as_np_dtype()) + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=f"Processing {source_config.source_name}", + ): + written.append(future.result()) + + return [path for path in written if path is not None] + + def _load_and_write_tokens( + self, + index: int, + path: PathOrStr, + dtype: NumpyDatasetDType, + tokens_to_take: int, + source_name: str, + taken: ValueLock, + write_lock: threading.Lock, + ) -> Optional[str]: + """ + Load tokens from a source file and write them to a local file. + """ + if taken.value() >= tokens_to_take: + return None + + filename = f"{self.output_dir}/{index:05}_{source_name}.npy" + print(f"Fetching {path} for {source_name}") + nda = load_array_slice(path, 0, tokens_to_take, dtype.as_np_dtype()) + print(f"Fetched {len(nda):.2e} tokens for {source_name} {path}") + + # TODO: Why are we repeating files and or have empty arrays? + with write_lock: + nda = nda[: tokens_to_take - taken.value()] + if len(nda) <= 0: + print(f"Skipping {path} as it has no tokens left") + return None with memmap_to_write( path=Path(filename), shape=(len(nda),), dtype=dtype.as_np_dtype() ) as mm: mm[:] = nda + taken.add(len(nda)) + print(f"Wrote {len(nda):.2e} tokens to {filename}") - written.append(filename) - tokens_taken += tokens_to_take - tokens_taken + return filename - return written - - def _count_tokens_for_source( - self, source_config: SourceMixtureConfig, dtype: NumpyDatasetDType - ) -> int: + def _count_tokens_for_paths(self, paths: List[PathOrStr]) -> int: """ - Count the number of tokens for a set of source token files in parallel. + Count the number of tokens for a set of source files in parallel. Args: source_config: The source configuration. dtype: The data type of the source tokens. """ - def _count_tokens(path) -> int: - size = get_file_size(path) - tokens = self._bytes_to_tokens(size, dtype) - return tokens - with ThreadPoolExecutor(max_workers=self.processes) as executor: - return sum(executor.map(_count_tokens, source_config.paths)) + return sum(executor.map(self._count_tokens_for_file, paths)) + + def _count_tokens_for_file(self, path) -> int: + return self._bytes_to_tokens(get_file_size(path), self.dtype) def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: """ diff --git a/src/test/data/mixture_dataset_test.py b/src/test/data/mixture_dataset_test.py index 26d8e365..3a3c6fce 100644 --- a/src/test/data/mixture_dataset_test.py +++ b/src/test/data/mixture_dataset_test.py @@ -241,7 +241,7 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): SourceMixtureConfig( source_name="1", target_ratio=0.5, - max_repetion_ratio=3.0, # Allow 3x repetition of source1 so that we can meet the target of 2.5M + max_repetition_ratio=3.0, # Allow 3x repetition of source1 so that we can meet the target of 2.5M paths=source_paths["1"], ), SourceMixtureConfig(source_name="2", target_ratio=0.25, paths=source_paths["2"]), @@ -376,19 +376,19 @@ def test_dataset_mixture_build_expected_files(tmp_path: Path): def test_dataset_mixture_render_table(tmp_path: Path, capsys): source_paths = { "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] + tmp_path=tmp_path, prefix="source1", num_files=5, size=DATA["tokens_per_file"] ), "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] + tmp_path=tmp_path, prefix="source2", num_files=5, size=DATA["tokens_per_file"] ), "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] + tmp_path=tmp_path, prefix="source3", num_files=5, size=DATA["tokens_per_file"] ), } source_configs = [ SourceMixtureConfig( source_name="1", - target_ratio=0.10, + target_ratio=0.30, paths=source_paths["1"], ), SourceMixtureConfig( @@ -398,12 +398,12 @@ def test_dataset_mixture_render_table(tmp_path: Path, capsys): ), SourceMixtureConfig( source_name="3", - target_ratio=0.5, + target_ratio=0.30, paths=source_paths["3"], ), ] - max_tokens = 10 * 1000 + max_tokens = 10_123_000 with TemporaryDirectory() as tmp_dir: config = SourceMixtureDatasetConfig( From 8649cc8aa563b5307bb20725021cd9ba3d863174 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 18 Oct 2024 18:25:37 -0700 Subject: [PATCH 05/57] WIP: Simple benchmark --- .../benchmark/data/mixture_dataset_bm.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100755 src/scripts/benchmark/data/mixture_dataset_bm.py diff --git a/src/scripts/benchmark/data/mixture_dataset_bm.py b/src/scripts/benchmark/data/mixture_dataset_bm.py new file mode 100755 index 00000000..e2ff2727 --- /dev/null +++ b/src/scripts/benchmark/data/mixture_dataset_bm.py @@ -0,0 +1,60 @@ +""" +Build a mixture dataset from a list of source datasets and benchmark it. +""" + +import logging +import os +import time +from tempfile import TemporaryDirectory + +import s3fs + +from olmo_core.data import NumpyDatasetDType +from olmo_core.data.mixture_dataset import SourceMixtureDatasetConfig, SourceMixtureConfig + +log = logging.getLogger(__name__) + + +def build_config(output_dir, processes) -> SourceMixtureDatasetConfig: + s3 = s3fs.S3FileSystem() + books = s3.glob("s3://ai2-llm/preprocessed/books/allenai_dolma2/*.npy") + dclm = s3.glob( + "s3://ai2-llm/preprocessed/dclm/text_openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train/allenai/dolma2-tokenizer/*.npy" + ) + + print(f"Found {len(books)} books files") + print(f"Found {len(dclm)} dclm files") + + return SourceMixtureDatasetConfig( + max_tokens=1_000_000_000, + source_configs=[ + SourceMixtureConfig( + source_name="books", + paths=[f"s3://{path}" for path in books], + max_repetition_ratio=1.0, + target_ratio=0.1, + ), + SourceMixtureConfig( + source_name="dclm", + paths=[f"s3://{path}" for path in dclm], + target_ratio=0.9, + ), + ], + dtype=NumpyDatasetDType.uint32, + output_dir=output_dir, + processes=processes, + seed=42, + dry_run=False, + ) + + +if __name__ == "__main__": + with TemporaryDirectory() as temp_dir: + processes = os.cpu_count() + # TODO: ADD DRY RUN TIME + print(f"Running with {processes} processes") + config_a = build_config(temp_dir, processes) + start_time = time.time() + dataset = config_a.build() + end_time = time.time() + print(f"Built dataset in {end_time - start_time:.2f} seconds") From e3d7011a5dcb093e0c663b7f8cc12ae84b5a0bda Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 15:17:19 -0700 Subject: [PATCH 06/57] WIP: Refactor --- src/examples/train_with_mixture.py | 231 ++++++++++++++ src/olmo_core/data/__init__.py | 5 +- src/olmo_core/data/mixture_dataset.py | 295 ------------------ src/olmo_core/data/numpy_dataset.py | 288 ++++++++++++++--- src/olmo_core/data/source_mixture.py | 270 ++++++++++++++++ src/olmo_core/data/types.py | 41 +++ src/olmo_core/data/utils.py | 3 + .../benchmark/data/mixture_dataset_bm.py | 60 ---- src/test/data/numpy_dataset_test.py | 86 ++++- ...dataset_test.py => source_mixture_test.py} | 254 ++++----------- 10 files changed, 943 insertions(+), 590 deletions(-) create mode 100644 src/examples/train_with_mixture.py delete mode 100644 src/olmo_core/data/mixture_dataset.py create mode 100644 src/olmo_core/data/source_mixture.py create mode 100644 src/olmo_core/data/types.py delete mode 100755 src/scripts/benchmark/data/mixture_dataset_bm.py rename src/test/data/{mixture_dataset_test.py => source_mixture_test.py} (54%) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py new file mode 100644 index 00000000..3d826142 --- /dev/null +++ b/src/examples/train_with_mixture.py @@ -0,0 +1,231 @@ +""" +Example of how to train a transformer language model. + +Launch this with torchrun: + + torchrun --nproc-per-node=4 src/examples/train.py run_name [OVERRIDES...] +""" + +import sys +from dataclasses import dataclass +from typing import List, cast, Union + +import s3fs + +from olmo_core.config import Config, DType +from olmo_core.data import ( + NumpyDataLoaderConfig, + NumpyDatasetConfig, + NumpyFSLDatasetMixtureConfig, + NumpyDatasetType, + TokenizerConfig, +) +from olmo_core.data.types import NumpyDatasetDType +from olmo_core.data.source_mixture import SourceMixtureConfig, SourceMixtureDatasetConfig +from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.distributed.utils import init_hybrid_shard_mesh +from olmo_core.nn.transformer import TransformerConfig +from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride +from olmo_core.train import ( + Duration, + TrainerConfig, + prepare_training_environment, + teardown_training_environment, +) +from olmo_core.train.callbacks import ( + CheckpointerCallback, + CometCallback, + ConfigSaverCallback, + GPUMemoryMonitorCallback, + GradClipperCallback, + LMEvaluatorCallbackConfig, + ProfilerCallback, + SchedulerCallback, + SequenceLengthSchedulerCallback, + WandBCallback, +) +from olmo_core.utils import get_default_device, seed_all + + +@dataclass +class ExperimentConfig(Config): + model: TransformerConfig + optim: AdamWConfig + dataset: Union[NumpyDatasetConfig, NumpyFSLDatasetMixtureConfig] + data_loader: NumpyDataLoaderConfig + trainer: TrainerConfig + init_seed: int = 12536 + + +def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: + tokenizer_config = TokenizerConfig.gpt2() + + model_config = TransformerConfig.llama2_271M( + vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128 + compile=True, + dp_config=DataParallelConfig( + name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 + ), + ) + + optim_config = AdamWConfig( + lr=1e-3, + group_overrides=[ + OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) + ], + ) + + s3 = s3fs.S3FileSystem() + + # DCLM docs + baseline = s3.glob( + "s3://ai2-llm/preprocessed/dclm/samples/src-100b/**/allenai/dolma2-tokenizer/*.npy" + ) + rewrites = s3.glob( + "s3://ai2-llm/preprocessed/dclm/samples/rewrite-100b/**/allenai/dolma2-tokenizer/*.npy" + ) + + sequence_length = 1024 + source_config = SourceMixtureDatasetConfig( + max_tokens=20_000_000, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + paths=[f"s3://{path}" for path in baseline], + source_name="baseline", + max_repetition_ratio=1.0, + target_ratio=0.7, + ), + SourceMixtureConfig( + source_name="rewrites", + paths=[f"s3://{path}" for path in rewrites], + target_ratio=0.3, + ), + ], + dtype=NumpyDatasetDType.uint32, + seed=42, + ) + + dataset_config = NumpyFSLDatasetMixtureConfig( + source_mixture_config=source_config, + sequence_length=sequence_length, + max_target_sequence_length=8192, + tokenizer=TokenizerConfig.dolma2(), + work_dir="/tmp/dataset-cache", + bust_index_cache=True, + ) + + data_loader_config = NumpyDataLoaderConfig( + global_batch_size=256 * 1024, + seed=0, + num_workers=4, + ) + + trainer_config = ( + TrainerConfig( + save_folder=f"/tmp/{run_name}", + rank_microbatch_size=16 * 1024, + save_overwrite=True, + metrics_collect_interval=5, + cancel_check_interval=5, + ) + .with_callback("lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=100))) + .with_callback( + "seq_len_scheduler", + SequenceLengthSchedulerCallback( + min_sequence_length=128, warmup_steps=100, enabled=False + ), + ) + .with_callback("gpu_monitor", GPUMemoryMonitorCallback()) + .with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0)) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=1000, + ephemeral_save_interval=100, + save_async=True, + ), + ) + .with_callback( + "comet", + CometCallback( + name=run_name, + cancel_check_interval=10, + enabled=False, # change to true to enable + ), + ) + .with_callback( + "wandb", + WandBCallback( + name=run_name, + cancel_check_interval=10, + enabled=False, # change to true to enable + ), + ) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback("profiler", ProfilerCallback(enabled=False)) + .with_callback( + "evaluator", + LMEvaluatorCallbackConfig( + eval_dataset=NumpyDatasetConfig( + paths=["/net/nfs/allennlp/llm-data/c4/en/c4-validation.00000-00008.npy"], + metadata=[{"label": "c4-validation"}], + name=NumpyDatasetType.padded_fsl, + sequence_length=sequence_length, + tokenizer=tokenizer_config, + work_dir="/tmp/dataset-cache", + ), + eval_interval=250, + eval_duration=Duration.steps(10), + ), + ) + ) + + return ExperimentConfig( + model=model_config, + optim=optim_config, + dataset=dataset_config, + data_loader=data_loader_config, + trainer=trainer_config, + ).merge(overrides) + + +def main(run_name: str, overrides: List[str]): + config = build_config(run_name, overrides) + + # Set RNG states on all devices. + seed_all(config.init_seed) + + # Build components. + model = config.model.build( + init_device="meta", + device=get_default_device(), + dp_mesh=init_hybrid_shard_mesh(num_replicas=2), + ) + optim = config.optim.build(model) + dataset = config.dataset.build() + data_loader = config.data_loader.build(dataset) + trainer = config.trainer.build(model, optim, data_loader) + + # Save config to W&B and each checkpoint dir. + config_dict = config.as_config_dict() + cast(CometCallback, trainer.callbacks["comet"]).config = config_dict + cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict + cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict + + # Train. + trainer.fit() + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]") + sys.exit(1) + + run_name, *overrides = sys.argv[1:] + + prepare_training_environment() + try: + main(run_name, overrides=overrides) + finally: + teardown_training_environment() diff --git a/src/olmo_core/data/__init__.py b/src/olmo_core/data/__init__.py index b3920e3a..3ddd0d23 100644 --- a/src/olmo_core/data/__init__.py +++ b/src/olmo_core/data/__init__.py @@ -24,8 +24,7 @@ from .numpy_dataset import ( NumpyDatasetBase, NumpyDatasetConfig, - NumpyDatasetDType, - NumpyDatasetType, + NumpyFSLDatasetMixtureConfig, NumpyFSLDataset, NumpyPaddedFSLDataset, NumpyVSLDataset, @@ -38,10 +37,12 @@ VSLNaturalCurriculum, ) from .tokenizer import TokenizerConfig, TokenizerName +from .types import NumpyDatasetDType, NumpyDatasetType __all__ = [ "NumpyDatasetBase", "NumpyFSLDataset", + "NumpyFSLDatasetMixtureConfig", "NumpyPaddedFSLDataset", "NumpyVSLDataset", "VSLCurriculum", diff --git a/src/olmo_core/data/mixture_dataset.py b/src/olmo_core/data/mixture_dataset.py deleted file mode 100644 index d7df8deb..00000000 --- a/src/olmo_core/data/mixture_dataset.py +++ /dev/null @@ -1,295 +0,0 @@ -import math -import multiprocessing as mp -import os -import random -import threading -from concurrent.futures import as_completed, ThreadPoolExecutor -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, List, Optional -from tqdm import tqdm - -import tabulate - -from olmo_core.aliases import PathOrStr -from olmo_core.config import Config -from olmo_core.data import NumpyDatasetDType -from olmo_core.data.utils import load_array_slice, memmap_to_write -from olmo_core.exceptions import OLMoConfigurationError -from olmo_core.io import get_file_size - - -class ValueLock: - def __init__(self): - self._lock = threading.Lock() - self._value = 0 - - def increment(self) -> int: - with self._lock: - self._value += 1 - return self._value - - def add(self, value) -> int: - with self._lock: - self._value = self._value + value - return self._value - - def value(self) -> int: - with self._lock: - return self._value - - -@dataclass -class SourceMixtureConfig(Config): - source_name: str - target_ratio: float - paths: List[PathOrStr] - # 1.0 will result in a maximum of 1 repitition of the source data per epoch - max_repetition_ratio: float = 1.0 - max_source_fraction: float = 1.0 - - def validate(self): - if self.target_ratio: - if not 0 <= self.target_ratio <= 1: - raise OLMoConfigurationError("target_ratio must be in the range [0, 1]") - if not 0 <= self.max_source_fraction <= 1: - raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") - if self.max_source_fraction < self.target_ratio: - raise OLMoConfigurationError("max_source_fraction must be >= target_ratio") - - if not self.paths: - raise OLMoConfigurationError("paths must not be empty") - - if not 0 <= self.max_source_fraction <= 1: - raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") - - -@dataclass -class SourceTokenDetails: - """ - A class to hold intermediate selection details for a mixture source. - """ - - source: SourceMixtureConfig - source_population: int - num_selected: int - - def for_table(self, max_tokens: int) -> Dict: - return { - "source_name": self.source.source_name, - "source_population": f"{self.source_population:.2e}", - "num_sampled": f"{self.num_selected:.2e}", - "target_ratio": self.source.target_ratio, - "max_repetion_ratio": self.source.max_repetition_ratio, - "max_source_fraction": self.source.max_source_fraction, - "observed_source_ratio": self.num_selected / self.source_population, - "observed_global_ratio": self.num_selected / max_tokens, - } - - -@dataclass -class SourceMixture: - """ - A fractionalized mixture of source tokens. - """ - - source_name: str - paths: List[str] - - -@dataclass -class SourceMixtureDataset: - """ - A dataset consisting of a fractionalized mixture of data sources. - """ - - sources: List[SourceMixture] - - -@dataclass -class SourceMixtureDatasetConfig(Config): - """ - A configuration class for building a dataset from a fractionalized mixture of sources. - """ - - max_tokens: int - source_configs: List[SourceMixtureConfig] - dtype: NumpyDatasetDType - output_dir: PathOrStr - processes: int = 1 - seed: int = 42 - dry_run: bool = False - - def __post_init__(self): - os.makedirs(self.output_dir, exist_ok=True) - - def validate(self): - if self.max_tokens <= 0: - raise OLMoConfigurationError("max_tokens must be > 0") - - if not self.source_configs: - raise OLMoConfigurationError("source_configs must not be empty") - - if (total := sum([source.target_ratio for source in self.source_configs])) != 1.0: - raise OLMoConfigurationError(f"target_ratios must sum to 1, got {total}") - - def build(self) -> SourceMixtureDataset: - self.validate() - random.seed(self.seed) - available_tokens_by_source: Dict[str, int] = {} - - # Count the number of tokens available for each source - for source_config in self.source_configs: - print("Counting tokens for source: ", source_config.source_name) - available_tokens_by_source[source_config.source_name] = self._count_tokens_for_paths( - source_config.paths - ) - - tokens_outcome_per_source: List[SourceTokenDetails] = [] - - # Calculate the number of tokens to include for each source - for source_config in self.source_configs: - available_for_source = available_tokens_by_source[source_config.source_name] - target_for_source = int(self.max_tokens * source_config.target_ratio) - max_for_source = int( - available_for_source - * source_config.max_source_fraction - * source_config.max_repetition_ratio - ) - - # Ensure that the available source tokens meet the target ratio requirement - if not max_for_source >= target_for_source: - raise OLMoConfigurationError( - f"Insufficient tokens for source: {source_config.source_name} @ target global ratio: {source_config.target_ratio} :: {max_for_source} < {target_for_source}" - ) - - tokens_outcome_per_source.append( - SourceTokenDetails( - source=source_config, - source_population=available_for_source, - num_selected=target_for_source, - ) - ) - - completed = [] - if not self.dry_run: - for outcome in tokens_outcome_per_source: - completed.append(self._handle_source_outcome(outcome)) - - print(f"Mixing outcome by source: {'' if not self.dry_run else '(DRY RUN)'}") - print( - tabulate.tabulate( - [item.for_table(self.max_tokens) for item in tokens_outcome_per_source], - headers="keys", - tablefmt="pretty", - ), - ) - - return SourceMixtureDataset(completed) - - def _handle_source_outcome(self, outcome: SourceTokenDetails) -> SourceMixture: - """ - Write selected tokens for a source to a local file and return the path. - """ - return SourceMixture( - source_name=outcome.source.source_name, - paths=self._write_tokens_for_source(self.dtype, outcome.num_selected, outcome.source), - ) - - def _write_tokens_for_source( - self, dtype: NumpyDatasetDType, tokens_to_take: int, source_config: SourceMixtureConfig - ) -> List[str]: - """ - Write selected tokens into a local file based on selection criteria. - """ - # Shuffle the paths to avoid biasing our selection to sequential file paths - paths = source_config.paths.copy() - random.shuffle(paths) - written: List[str] = [] - taken = ValueLock() - m = mp.Manager() - write_lock = m.Lock() - - with ThreadPoolExecutor(max_workers=self.processes) as executor: - print(f"Collecting {tokens_to_take:.2e} tokens for {source_config.source_name}") - futures = [] - for idx, path in enumerate(paths * math.ceil(source_config.max_repetition_ratio)): - futures.append( - executor.submit( - self._load_and_write_tokens, - idx, - path, - dtype, - tokens_to_take, - source_config.source_name, - taken, - write_lock, - ) - ) - - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Processing {source_config.source_name}", - ): - written.append(future.result()) - - return [path for path in written if path is not None] - - def _load_and_write_tokens( - self, - index: int, - path: PathOrStr, - dtype: NumpyDatasetDType, - tokens_to_take: int, - source_name: str, - taken: ValueLock, - write_lock: threading.Lock, - ) -> Optional[str]: - """ - Load tokens from a source file and write them to a local file. - """ - if taken.value() >= tokens_to_take: - return None - - filename = f"{self.output_dir}/{index:05}_{source_name}.npy" - print(f"Fetching {path} for {source_name}") - nda = load_array_slice(path, 0, tokens_to_take, dtype.as_np_dtype()) - print(f"Fetched {len(nda):.2e} tokens for {source_name} {path}") - - # TODO: Why are we repeating files and or have empty arrays? - with write_lock: - nda = nda[: tokens_to_take - taken.value()] - if len(nda) <= 0: - print(f"Skipping {path} as it has no tokens left") - return None - with memmap_to_write( - path=Path(filename), shape=(len(nda),), dtype=dtype.as_np_dtype() - ) as mm: - mm[:] = nda - taken.add(len(nda)) - print(f"Wrote {len(nda):.2e} tokens to {filename}") - - return filename - - def _count_tokens_for_paths(self, paths: List[PathOrStr]) -> int: - """ - Count the number of tokens for a set of source files in parallel. - - Args: - source_config: The source configuration. - dtype: The data type of the source tokens. - """ - - with ThreadPoolExecutor(max_workers=self.processes) as executor: - return sum(executor.map(self._count_tokens_for_file, paths)) - - def _count_tokens_for_file(self, path) -> int: - return self._bytes_to_tokens(get_file_size(path), self.dtype) - - def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: - """ - Convert bytes to tokens based on the dtype. - """ - npdtype = dtype.as_np_dtype() - return num_bytes // npdtype(int(0)).itemsize diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index b6c9ad73..5f6c6ae3 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -30,6 +30,8 @@ from torch.utils.data import Dataset from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError +from olmo_core.data.types import NumpyDatasetType, NumpyDatasetDType, SupportedDType +from olmo_core.data.source_mixture import SourceMixtureDatasetConfig, SourcePathTokens from ..aliases import PathOrStr from ..config import Config, StrEnum @@ -53,6 +55,7 @@ __all__ = [ "NumpyDatasetBase", "NumpyFSLDataset", + "NumpyFSLDatasetMixtureConfig", "NumpyPaddedFSLDataset", "VSLCurriculum", "VSLNaturalCurriculum", @@ -60,11 +63,9 @@ "VSLGrowP2Curriculum", "VSLGrowLinearCurriculum", "NumpyVSLDataset", - "NumpyDatasetType", "NumpyDatasetConfig", "VSLCurriculumType", "VSLCurriculumConfig", - "NumpyDatasetDType", ] @@ -99,7 +100,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: SupportedDType = np.uint16, ): if not paths: raise OLMoConfigurationError("At least one path is required") @@ -153,7 +154,7 @@ def vocab_size(self) -> int: @property def dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> SupportedDType: """ The numpy datatype of the arrays. """ @@ -347,7 +348,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: SupportedDType = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, generate_doc_lengths: bool = False, @@ -516,7 +517,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: SupportedDType = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, ): @@ -550,7 +551,7 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]: @property def indices_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> SupportedDType: return np.uint32 def prepare(self): @@ -945,7 +946,7 @@ def __init__( max_sequence_length: int, min_sequence_length: int = 256, curriculum: Optional[VSLCurriculum] = None, - dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16, + dtype: SupportedDType = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, ): @@ -985,9 +986,7 @@ def __init__( self._curriculum = curriculum or VSLNaturalCurriculum() self._num_instances: Optional[int] = None self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None - self._lengths_dtype: Optional[ - Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] - ] = None + self._lengths_dtype: Optional[SupportedDType] = None self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None @property @@ -1226,13 +1225,13 @@ def instances_per_bucket(self) -> Tuple[Tuple[int, int], ...]: @property def indices_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> SupportedDType: return np.uint32 @property def lengths_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> SupportedDType: if self._lengths_dtype is None: for dtype in ( np.uint8, @@ -1247,39 +1246,6 @@ def lengths_dtype( return self._lengths_dtype -class NumpyDatasetType(StrEnum): - """ - An enumeration of the different :class:`NumpyDatasetBase` implementations. - """ - - fsl = "fsl" - """ - Fixed sequenced length ➡️ :class:`NumpyFSLDataset`. - """ - - padded_fsl = "padded_fsl" - """ - Padded fixed sequence length ➡️ :class:`NumpyPaddedFSLDataset`. - """ - - vsl = "vsl" - """ - Variable sequenced length ➡️ :class:`NumpyVSLDataset`. - """ - - -class NumpyDatasetDType(StrEnum): - uint8 = "uint8" - uint16 = "uint16" - uint32 = "uint32" - uint64 = "uint64" - - def as_np_dtype( - self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: - return getattr(np, str(self)) - - class VSLCurriculumType(StrEnum): """ An enumeration of the different VSL curriculum implementations. @@ -1480,7 +1446,7 @@ def from_data_mix( def get_dtype( self, - ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + ) -> SupportedDType: if self.dtype is not None: return NumpyDatasetDType(self.dtype).as_np_dtype() @@ -1638,3 +1604,231 @@ def build(self) -> NumpyDatasetBase: dataset.work_dir = Path(self.work_dir) return dataset + + +class NumpyFSLDatasetMixture(NumpyFSLDataset): + def __init__( + self, + *paths: PathOrStr, + path_offset_index: Dict[str, int], + sequence_length: int, + pad_token_id: int, + eos_token_id: int, + vocab_size: int, + dtype: SupportedDType = np.uint16, + metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + include_instance_metadata: Optional[bool] = None, + generate_doc_lengths: bool = False, + max_target_sequence_length: Optional[int] = None, + bust_index_cache: bool = False, + ): + super().__init__( + *paths, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + vocab_size=vocab_size, + dtype=dtype, + sequence_length=sequence_length, + metadata=metadata, + include_instance_metadata=include_instance_metadata, + generate_doc_lengths=generate_doc_lengths, + max_target_sequence_length=max_target_sequence_length, + ) + self._metadata = metadata + self._include_instance_metadata = include_instance_metadata + self._num_instances: Optional[int] = None + self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None + self._lengths_dtype: Optional[SupportedDType] = None + self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None + self._path_offset_index = path_offset_index + self._bust_index_cache = bust_index_cache + + # TODO: overload __getitem__ to read the stuff we need, maybe just with read_chunk_from_array + + def _get_indices_path(self, path: PathOrStr) -> Path: + sha256_hash = hashlib.sha256() + sha256_hash.update(str(path).encode()) + sha256_hash.update(str(self._get_file_size(path)).encode()) + path_hash = sha256_hash.hexdigest() + return ( + self.work_dir + / "dataset-common" + / f"mixture-instance-indices-{self.sequence_length}-{path_hash}.npy" + ) + + def _write_document_indices(self): + paths_needed: List[PathOrStr] = [] + for path in self.paths: + indices_path = self._get_indices_path(path) + if indices_path.is_file() and not self._bust_index_cache: + log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'") + elif path not in paths_needed: + paths_needed.append(path) + + if paths_needed: + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] + for path in paths_needed: + indices_path = self._get_indices_path(path) + log.info(f"Gathering instance indices for '{path}'...") + # NOTE: We limit the number of instances to the number by total target token count + max_instances = self._path_offset_index[str(path)] // self.sequence_length + future = executor.submit( + run_worker_func, + segment_documents_into_instances, + path, + indices_path, + max_sequence_length=self.sequence_length, + eos_token_id=self.eos_token_id, + dtype=self.dtype, + indices_dtype=self.dtype, + max_instances=max_instances, + ) + futures.append(future) + + concurrent.futures.wait(futures, return_when="ALL_COMPLETED") + + # Log results. + for path, future in zip(paths_needed, futures): + _, total_instances = future.result() + log.info( + f"Created {total_instances:,d} instances of sequence length up to " + f"{self.sequence_length} from '{path}'" + ) + + def prepare(self): + if self.fs_local_rank == 0: + log.info("Gathering indices...") + self._write_document_indices() + barrier() + len(self) + + def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: + start_idx = index * self.sequence_length + return load_array_slice_into_tensor( + path, start_idx, start_idx + self.sequence_length, self.dtype + ) + + def _get_file_size_and_length( + self, path: PathOrStr, dtype: Optional[SupportedDType] = None + ) -> Tuple[int, int]: + dtype = dtype or self.dtype + item_size = dtype(0).itemsize + file_size = self._get_size_from_offset_index(path) + if ( + self.max_target_sequence_length is None + or self.max_target_sequence_length == self.sequence_length + ): + return file_size, file_size // (item_size * self.sequence_length) + elif self.max_target_sequence_length > self.sequence_length: + num_max_seq_len_instances = file_size // (item_size * self.max_target_sequence_length) + return ( + file_size, + num_max_seq_len_instances + * (self.max_target_sequence_length // self.sequence_length), + ) + else: + raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") + + def _get_size_from_offset_index(self, path: PathOrStr) -> int: + try: + # Get size in bytes from tokens in the supplied index * itemsize + return self._path_offset_index[str(path)] * self.dtype(0).itemsize + except KeyError: + raise OLMoEnvironmentError(f"Path {path} not found in path index") + + +@dataclass +class NumpyFSLDatasetMixtureConfig(Config): + """ + A config class for easily building :class:`NumpyFSLDatasetMixture` class. + This is a special case of :class:`NumpyFSLDataset` that is built from a mixture of source + datasets based on a mixture configuration. + """ + + source_mixture_config: SourceMixtureDatasetConfig + """ + The source mixture dataset config. + """ + tokenizer: TokenizerConfig + """ + The tokenizer config. + """ + sequence_length: Optional[int] = None + """ + The sequence length for a :class:`NumpyFSLDataset`. + """ + max_target_sequence_length: Optional[int] = None + """ + The max target sequene length for a :class:`NumpyFSLDataset`. + """ + dtype: Optional[NumpyDatasetDType] = None + """ + The numpy datatype of the token ID arrays. + """ + metadata: Optional[List[Dict[str, Any]]] = None + """ + Metadata for the numpy arrays. + """ + include_instance_metadata: bool = True + """ + Whether or not to include the :data:`metadata` in the instances returned from + :meth:`NumpyDatasetBase.__getitem__()`. + """ + generate_doc_lengths: bool = False + """ + Include individual document lengths in the instances returned from + :meth:`NumpyDatasetBase.__getitem__()`. + """ + work_dir: Optional[str] = None + """ + The dataset working directory. This is used to cache working files like shuffled indices, + instance buckets, etc. + + .. tip:: + You can save a lot of time and disk space by setting this to a common directory across + all of you runs. + """ + bust_index_cache: bool = False + """ + Whether or not to bust the index cache. + """ + + def get_dtype( + self, + ) -> SupportedDType: + if self.dtype is not None: + return NumpyDatasetDType(self.dtype).as_np_dtype() + + # Guess based on vocab size. + for dtype in ( + NumpyDatasetDType.uint8, + NumpyDatasetDType.uint16, + NumpyDatasetDType.uint32, + NumpyDatasetDType.uint64, + ): + if (self.tokenizer.vocab_size - 1) <= np.iinfo(dtype.as_np_dtype()).max: + log.info(f"Assuming dtype '{dtype}' based on vocab size") + return dtype.as_np_dtype() + + raise ValueError("vocab size too big!") + + def build(self) -> NumpyFSLDataset: + """ + Construct the corresponding :class:`NumpyFSLDatasetMixture`. + """ + mixture = self.source_mixture_config.build().to_path_instance_index() + return NumpyFSLDatasetMixture( + *mixture.keys(), + sequence_length=self.sequence_length or 1024, + max_target_sequence_length=self.max_target_sequence_length, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + vocab_size=self.tokenizer.vocab_size, + dtype=self.get_dtype(), + metadata=self.metadata, + include_instance_metadata=self.include_instance_metadata, + generate_doc_lengths=self.generate_doc_lengths, + path_offset_index=mixture, + bust_index_cache=self.bust_index_cache, + ) diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py new file mode 100644 index 00000000..cfb4d7b0 --- /dev/null +++ b/src/olmo_core/data/source_mixture.py @@ -0,0 +1,270 @@ +import logging +import math +import random +from itertools import chain +from concurrent.futures import as_completed, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Dict, List, Optional + +import tabulate +from tqdm import tqdm + +from olmo_core.aliases import PathOrStr +from olmo_core.config import Config +from olmo_core.data.types import NumpyDatasetDType +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.io import get_file_size + +__all__ = [ + "SourceMixtureConfig", + "SourceMixtureDataset", + "SourceMixtureDatasetConfig", +] + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger(__name__) + +# Disable some noisy loggers +for name in logging.Logger.manager.loggerDict.keys(): + if name in ( + "boto", + "urllib3", + "s3transfer", + "boto3", + "botocore", + "aiobotocore", + "nose", + ): + logging.getLogger(name).setLevel(logging.CRITICAL) + + +@dataclass +class SourceMixtureConfig(Config): + source_name: str + target_ratio: float + paths: List[PathOrStr] + # 1.0 will result in a maximum of 1 repitition of the source data per epoch + max_repetition_ratio: float = 1.0 + max_source_fraction: float = 1.0 + + def validate(self): + if self.target_ratio: + if not 0 <= self.target_ratio <= 1: + raise OLMoConfigurationError("target_ratio must be in the range [0, 1]") + if not 0 <= self.max_source_fraction <= 1: + raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") + if self.max_source_fraction < self.target_ratio: + raise OLMoConfigurationError("max_source_fraction must be >= target_ratio") + + if not self.paths: + raise OLMoConfigurationError("paths must not be empty") + + if not 0 <= self.max_source_fraction <= 1: + raise OLMoConfigurationError("max_source_fraction must be in the range [0, 1]") + + +@dataclass +class SourceTokenDetails: + """ + A class to hold intermediate selection details for a mixture source. + """ + + config: SourceMixtureConfig + population: int + num_selected: int + + def for_table(self, max_tokens: int) -> Dict: + return { + "source_name": self.config.source_name, + "source_population": f"{self.population:.2e}", + "num_sampled": f"{self.num_selected:.2e}", + "target_ratio": self.config.target_ratio, + "max_repetion_ratio": self.config.max_repetition_ratio, + "max_source_fraction": self.config.max_source_fraction, + "observed_source_ratio": f"{(self.num_selected / self.population):.4}", + "observed_global_ratio": f"{(self.num_selected / max_tokens):.4}", + } + + +@dataclass +class SourcePathTokens: + path: PathOrStr + tokens: int + + +@dataclass +class SourceMixtureOutcome: + name: str + path_tokens: List[SourcePathTokens] + + +@dataclass +class SourceMixtureDataset: + """ + A dataset consisting of a fractionalized mixture of data sources. + """ + + sources: List[SourceMixtureOutcome] + + def to_path_instance_index(self) -> Dict[str, int]: + """ + Convert the dataset to a dictionary of paths and instance counts to retain. + """ + outcomes = chain.from_iterable([outcome.path_tokens for outcome in self.sources]) + return {str(outcome.path): outcome.tokens for outcome in outcomes} + + +@dataclass +class SourceMixtureDatasetConfig(Config): + """ + A configuration class for building a dataset from a fractionalized mixture of sources. + """ + + max_tokens: int + source_configs: List[SourceMixtureConfig] + sequence_length: int + dtype: NumpyDatasetDType + processes: int = 1 + seed: int = 42 + + def validate(self): + if self.max_tokens <= 0: + raise OLMoConfigurationError("max_tokens must be > 0") + + if not self.source_configs: + raise OLMoConfigurationError("source_configs must not be empty") + + if (total := sum([source.target_ratio for source in self.source_configs])) != 1.0: + raise OLMoConfigurationError(f"target_ratios must sum to 1, got {total}") + + def build(self) -> SourceMixtureDataset: + self.validate() + random.seed(self.seed) + available_tokens_by_source: Dict[str, int] = {} + + # Count the number of tokens available for each source + for source_config in self.source_configs: + log.info(f"Counting tokens for source: {source_config.source_name}") + available_tokens_by_source[source_config.source_name] = self._count_tokens_for_paths( + paths=source_config.paths, source=source_config.source_name + ) + + tokens_details_by_source: List[SourceTokenDetails] = [] + + # Calculate the number of tokens to include for each source + for source_config in self.source_configs: + num_for_source = available_tokens_by_source[source_config.source_name] + needed_for_source = int(self.max_tokens * source_config.target_ratio) + max_for_source = int( + (num_for_source * source_config.max_source_fraction) + * source_config.max_repetition_ratio + ) + + # Ensure that the max tokens for a source meet the target ratio requirement + if max_for_source < needed_for_source: + raise OLMoConfigurationError( + f"Insufficient tokens for source: {source_config.source_name} @ target global ratio: {source_config.target_ratio} :: {max_for_source} < {needed_for_source}" + ) + + tokens_details_by_source.append( + SourceTokenDetails( + config=source_config, + population=num_for_source, + num_selected=needed_for_source, + ) + ) + + completed: List[SourceMixtureOutcome] = [] + for source in tokens_details_by_source: + completed.append( + SourceMixtureOutcome( + name=source.config.source_name, + path_tokens=self.get_paths_and_tokens_for_source( + source_config=source.config, + take_ratio=source.num_selected / source.population, + ), + ) + ) + + log.info("Outcome by source => ") + print( + tabulate.tabulate( + [item.for_table(self.max_tokens) for item in tokens_details_by_source], + headers="keys", + tablefmt="pretty", + ), + ) + + total_tokens = sum([item.population for item in tokens_details_by_source]) + selected_tokens = sum([item.num_selected for item in tokens_details_by_source]) + observed_global_ratio = selected_tokens / total_tokens + + log.info("Global outcome => ") + print( + tabulate.tabulate( + [ + { + "total_tokens": f"{total_tokens:.2e}", + "selected_tokens": f"{selected_tokens:.2e}", + "observed_global_ratio": f"{observed_global_ratio:.4}", + } + ], + tablefmt="pretty", + headers="keys", + ), + ) + + for source in completed: + for item in source.path_tokens: + log.info(f"Selected {item.tokens} tokens from {source.name} at {item.path}") + + return SourceMixtureDataset(completed) + + def get_paths_and_tokens_for_source( + self, source_config: SourceMixtureConfig, take_ratio: float + ) -> List[SourcePathTokens]: + """ + Get the paths and resulting token count for a source. + """ + # TODO: Handle repetition ratio by adding paths multiple times, max_repetition_ratio + path_tokens = [] + for path in source_config.paths: + tokens_to_keep = int(math.ceil(self._count_tokens_for_file(path) * take_ratio)) + path_tokens.append(SourcePathTokens(path=path, tokens=tokens_to_keep)) + + return path_tokens + + def _count_tokens_for_paths(self, paths: List[PathOrStr], source: Optional[str]) -> int: + """ + Count the number of tokens for a set of source files in parallel. + + Args: + source_config: The source configuration. + dtype: The data type of the source tokens. + """ + + with ThreadPoolExecutor(max_workers=self.processes) as executor: + futures = [] + for path in paths: + futures.append(executor.submit(self._count_tokens_for_file, path)) + + return sum( + [ + future.result() + for future in tqdm( + as_completed(futures), + total=len(futures), + desc=f"Counting tokens {'for ' + source if source else ''}", + ) + ] + ) + + def _count_tokens_for_file(self, path: PathOrStr) -> int: + return self._bytes_to_tokens(get_file_size(path), self.dtype) + + def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: + """ + Convert bytes to tokens based on the dtype. + """ + npdtype = dtype.as_np_dtype() + return num_bytes // npdtype(int(0)).itemsize diff --git a/src/olmo_core/data/types.py b/src/olmo_core/data/types.py new file mode 100644 index 00000000..f6cffb2b --- /dev/null +++ b/src/olmo_core/data/types.py @@ -0,0 +1,41 @@ +from typing import Type, Union + +import numpy as np + +from olmo_core.config import StrEnum + + +SupportedDType = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] + + +class NumpyDatasetType(StrEnum): + """ + An enumeration of the different :class:`NumpyDatasetBase` implementations. + """ + + fsl = "fsl" + """ + Fixed sequenced length ➡️ :class:`NumpyFSLDataset`. + """ + + padded_fsl = "padded_fsl" + """ + Padded fixed sequence length ➡️ :class:`NumpyPaddedFSLDataset`. + """ + + vsl = "vsl" + """ + Variable sequenced length ➡️ :class:`NumpyVSLDataset`. + """ + + +class NumpyDatasetDType(StrEnum): + uint8 = "uint8" + uint16 = "uint16" + uint32 = "uint32" + uint64 = "uint64" + + def as_np_dtype( + self, + ) -> Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]]: + return getattr(np, str(self)) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index 2a5293db..4834e1e5 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -411,6 +411,7 @@ def segment_documents_into_instances( indices_dtype: Union[ Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64] ] = np.uint32, + max_instances: Optional[int] = None, ) -> Tuple[int, int]: """ Segment documents into instances of at most ``sequence_length`` tokens. @@ -421,6 +422,8 @@ def segment_documents_into_instances( total_og_docs = 0 indices = [] for start_idx, end_idx in iter_document_indices(path, eos_token_id=eos_token_id, dtype=dtype): + if max_instances and len(indices) // 2 >= max_instances: + break total_og_docs += 1 length = end_idx - start_idx indices.append(start_idx) diff --git a/src/scripts/benchmark/data/mixture_dataset_bm.py b/src/scripts/benchmark/data/mixture_dataset_bm.py deleted file mode 100755 index e2ff2727..00000000 --- a/src/scripts/benchmark/data/mixture_dataset_bm.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Build a mixture dataset from a list of source datasets and benchmark it. -""" - -import logging -import os -import time -from tempfile import TemporaryDirectory - -import s3fs - -from olmo_core.data import NumpyDatasetDType -from olmo_core.data.mixture_dataset import SourceMixtureDatasetConfig, SourceMixtureConfig - -log = logging.getLogger(__name__) - - -def build_config(output_dir, processes) -> SourceMixtureDatasetConfig: - s3 = s3fs.S3FileSystem() - books = s3.glob("s3://ai2-llm/preprocessed/books/allenai_dolma2/*.npy") - dclm = s3.glob( - "s3://ai2-llm/preprocessed/dclm/text_openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train/allenai/dolma2-tokenizer/*.npy" - ) - - print(f"Found {len(books)} books files") - print(f"Found {len(dclm)} dclm files") - - return SourceMixtureDatasetConfig( - max_tokens=1_000_000_000, - source_configs=[ - SourceMixtureConfig( - source_name="books", - paths=[f"s3://{path}" for path in books], - max_repetition_ratio=1.0, - target_ratio=0.1, - ), - SourceMixtureConfig( - source_name="dclm", - paths=[f"s3://{path}" for path in dclm], - target_ratio=0.9, - ), - ], - dtype=NumpyDatasetDType.uint32, - output_dir=output_dir, - processes=processes, - seed=42, - dry_run=False, - ) - - -if __name__ == "__main__": - with TemporaryDirectory() as temp_dir: - processes = os.cpu_count() - # TODO: ADD DRY RUN TIME - print(f"Running with {processes} processes") - config_a = build_config(temp_dir, processes) - start_time = time.time() - dataset = config_a.build() - end_time = time.time() - print(f"Built dataset in {end_time - start_time:.2f} seconds") diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 19e0bd19..2dd28e28 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -1,18 +1,49 @@ from pathlib import Path -from typing import List +from typing import List, Tuple import numpy as np from olmo_core.data import ( NumpyDatasetConfig, + NumpyFSLDatasetMixtureConfig, NumpyFSLDataset, NumpyPaddedFSLDataset, NumpyVSLDataset, TokenizerConfig, ) + +from olmo_core.aliases import PathOrStr +from olmo_core.data.types import NumpyDatasetDType +from olmo_core.data.source_mixture import SourceMixtureDatasetConfig, SourceMixtureConfig from olmo_core.data.utils import get_document_indices, write_document_indices +def _make_mmaps( + tmp_path: Path, + prefix: str, + num_files: int, + size: int, + dtype, + eos: int, + seq_length: int = 4, + seed: int = 42, +) -> List[Tuple[PathOrStr, List]]: + mmaps = [] + for i in range(num_files): + filepath = f"{tmp_path}/{prefix}_{i}.npy" + np.random.seed(seed) + data = np.random.randint(0, np.iinfo(dtype).max, size=size, dtype=dtype) + data = np.append( + np.insert(data, np.arange(seq_length + 1, len(data), seq_length), eos), eos + ) + mm = np.memmap(filepath, mode="w+", dtype=dtype, shape=(len(data),)) + mm[:] = data + mm.flush() + mmaps.append((Path(filepath), data)) + + return mmaps + + def test_numpy_fsl_dataset(tmp_path: Path): mmap1 = np.memmap(tmp_path / "mmap1.npy", mode="w+", dtype=np.uint16, shape=(16,)) mmap1[:] = list(range(16)) @@ -65,6 +96,59 @@ def test_numpy_padded_fsl_dataset(tmp_path: Path): assert len(ds) == 4 +def test_numpy_fsl_mixture_dataset(tmp_path: Path): + # NOTE: At very small token counts (10's of tokens) + # the take_ratio gets finicky so we test at small but real world-ish scale) + npdtype = np.uint16 + seed = 42 + mmap1 = _make_mmaps(tmp_path, "mmap1", 1, 20 * 1000, npdtype, eos=0, seed=seed) + mmap2 = _make_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) + + sequence_length = 4 + tokenizer = TokenizerConfig( + vocab_size=32_000, + eos_token_id=0, + pad_token_id=-1, + ) + + mixture_config = SourceMixtureDatasetConfig( + max_tokens=10_000, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + source_name="mmap1", + paths=[i[0] for i in mmap1], + target_ratio=0.8, + ), + SourceMixtureConfig( + source_name="mmap2", + paths=[i[0] for i in mmap2], + target_ratio=0.2, + ), + ], + dtype=NumpyDatasetDType.uint16, + processes=1, + seed=seed, + ) + + ds = NumpyFSLDatasetMixtureConfig( + source_mixture_config=mixture_config, + sequence_length=sequence_length, + tokenizer=tokenizer, + bust_index_cache=True, + include_instance_metadata=False, + ).build() + ds.prepare() + + expected = "68144f" + assert ds.fingerprint.endswith( + expected + ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update the expected fingerprint?" + assert ds[0]["input_ids"].tolist() == [56422, 24545, 15795, 52202] + assert ds.num_tokens == 10000 + assert len(ds) == 2500 + + def write_data_file(data: List[int], path: Path, dtype, eos_token_id: int): path.parent.mkdir(exist_ok=True, parents=True) mmap = np.memmap(path, mode="w+", dtype=dtype, shape=(len(data),)) diff --git a/src/test/data/mixture_dataset_test.py b/src/test/data/source_mixture_test.py similarity index 54% rename from src/test/data/mixture_dataset_test.py rename to src/test/data/source_mixture_test.py index 3a3c6fce..f035bb3c 100644 --- a/src/test/data/mixture_dataset_test.py +++ b/src/test/data/source_mixture_test.py @@ -8,7 +8,7 @@ from olmo_core.aliases import PathOrStr from olmo_core.data import NumpyDatasetDType -from olmo_core.data.mixture_dataset import ( +from olmo_core.data.source_mixture import ( SourceMixtureConfig, SourceMixtureDataset, SourceMixtureDatasetConfig, @@ -35,7 +35,7 @@ def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> List[ return mmaps -def test_source_mixture_config_dry_run(tmp_path: Path, capsys): +def test_source_mixture_config(tmp_path: Path, capsys): source_paths = { "1": _make_mmaps( tmp_path=tmp_path, prefix="source1", num_files=2, size=DATA["tokens_per_file"] @@ -64,19 +64,17 @@ def test_source_mixture_config_dry_run(tmp_path: Path, capsys): max_tokens = 5_000_000 - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, - output_dir=tmp_dir, - dry_run=True, - ) + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) - with capsys.disabled(): - print("\n") - mixture = config.build() - assert isinstance(mixture, SourceMixtureDataset) + with capsys.disabled(): + print("\n") + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) def test_source_mixture_config_validation(): @@ -108,33 +106,28 @@ def test_dataset_mixture_config_validation(): SourceMixtureConfig(source_name="source2", target_ratio=0.5, paths=["/path/to/source2"]), ] - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=1000, - source_configs=source_configs, - dtype=NumpyDatasetDType.uint32, - output_dir=tmp_dir, - ) - config.validate() - - source_configs_invalid = [ - SourceMixtureConfig( - source_name="source1", target_ratio=0.7, paths=["/path/to/source1"] - ), - SourceMixtureConfig( - source_name="source2", target_ratio=0.5, paths=["/path/to/source2"] - ), - ] - - config_invalid = SourceMixtureDatasetConfig( - max_tokens=1000, - source_configs=source_configs_invalid, - dtype=NumpyDatasetDType.uint32, - output_dir=tmp_dir, - ) - - with pytest.raises(OLMoConfigurationError): - config_invalid.validate() + config = SourceMixtureDatasetConfig( + max_tokens=1000, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + config.validate() + + source_configs_invalid = [ + SourceMixtureConfig(source_name="source1", target_ratio=0.7, paths=["/path/to/source1"]), + SourceMixtureConfig(source_name="source2", target_ratio=0.5, paths=["/path/to/source2"]), + ] + + config_invalid = SourceMixtureDatasetConfig( + max_tokens=1000, + source_configs=source_configs_invalid, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + with pytest.raises(OLMoConfigurationError): + config_invalid.validate() def test_dataset_mixture_build(tmp_path: Path): @@ -166,16 +159,15 @@ def test_dataset_mixture_build(tmp_path: Path): max_tokens = 5_000_000 - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=DATA["dtype"], - output_dir=tmp_dir, - ) + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=DATA["dtype"], + sequence_length=1024, + ) - mixture = config.build() - assert isinstance(mixture, SourceMixtureDataset) + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): @@ -206,17 +198,16 @@ def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): max_tokens = 5_000_000 - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=DATA["dtype"], - output_dir=tmp_dir, - ) + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=DATA["dtype"], + sequence_length=1024, + ) - # Should raise exception because the target ratio for source 1 @50% (2.5M) is infeasible without repetition (default max_repetition_ratio=1) - with pytest.raises(OLMoConfigurationError): - config.build() + # Should raise exception because the target ratio for source 1 @50% (2.5M) is infeasible without repetition (default max_repetition_ratio=1) + with pytest.raises(OLMoConfigurationError): + config.build() def test_dataset_mixture_build_with_repetition(tmp_path: Path): @@ -254,16 +245,15 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): max_tokens = 5_000_000 - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=DATA["dtype"], - output_dir=tmp_dir, - ) + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=DATA["dtype"], + sequence_length=1024, + ) - mixture = config.build() - assert isinstance(mixture, SourceMixtureDataset) + mixture = config.build() + assert isinstance(mixture, SourceMixtureDataset) def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): @@ -300,120 +290,14 @@ def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): # 5 source files * 1_000_000 tokens per file max_tokens = len(list(chain(*source_paths.values()))) * DATA["tokens_per_file"] - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=DATA["dtype"], - output_dir=tmp_dir, - ) - - # Should raise exception because the target ratio for source 1 is infeasible because - # we limit usage to 10% of the source - with pytest.raises(OLMoConfigurationError): - config.build() - - -def test_dataset_mixture_build_expected_files(tmp_path: Path): - source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] - ), - } - source_configs = [ - SourceMixtureConfig( - source_name="1", - target_ratio=0.10, - paths=source_paths["1"], - ), - SourceMixtureConfig( - source_name="2", - target_ratio=0.40, - paths=source_paths["2"], - ), - SourceMixtureConfig( - source_name="3", - target_ratio=0.5, - paths=source_paths["3"], - ), - ] - - max_tokens = 10 * 1000 - - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=DATA["dtype"], - output_dir=tmp_dir, - ) - - mixture = config.build() - assert isinstance(mixture, SourceMixtureDataset) - - out_tokens = [] - - for source in mixture.sources: - for path in source.paths: - out_tokens.extend( - load_array_slice( - path=path, - start_idx=0, - end_idx=DATA["tokens_per_file"], - dtype=DATA["dtype"].as_np_dtype(), - ) - ) - - assert len(out_tokens) == max_tokens - - -def test_dataset_mixture_render_table(tmp_path: Path, capsys): - source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=5, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=5, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=5, size=DATA["tokens_per_file"] - ), - } - source_configs = [ - SourceMixtureConfig( - source_name="1", - target_ratio=0.30, - paths=source_paths["1"], - ), - SourceMixtureConfig( - source_name="2", - target_ratio=0.40, - paths=source_paths["2"], - ), - SourceMixtureConfig( - source_name="3", - target_ratio=0.30, - paths=source_paths["3"], - ), - ] - - max_tokens = 10_123_000 - - with TemporaryDirectory() as tmp_dir: - config = SourceMixtureDatasetConfig( - max_tokens=max_tokens, - source_configs=source_configs, - dtype=DATA["dtype"], - output_dir=tmp_dir, - ) + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=DATA["dtype"], + sequence_length=1024, + ) - with capsys.disabled(): - print("\n") - mixture = config.build() - assert isinstance(mixture, SourceMixtureDataset) + # Should raise exception because the target ratio for source 1 is infeasible because + # we limit usage to 10% of the source + with pytest.raises(OLMoConfigurationError): + config.build() From efe766b1014912ace44499c395ceb61d86e83722 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 15:32:11 -0700 Subject: [PATCH 07/57] Launch script --- src/examples/train_with_mixture.py | 2 +- src/examples/train_with_mixture_launch.py | 41 +++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 src/examples/train_with_mixture_launch.py diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 3d826142..e30b86ea 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -3,7 +3,7 @@ Launch this with torchrun: - torchrun --nproc-per-node=4 src/examples/train.py run_name [OVERRIDES...] + torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name [OVERRIDES...] """ import sys diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py new file mode 100644 index 00000000..4298f25f --- /dev/null +++ b/src/examples/train_with_mixture_launch.py @@ -0,0 +1,41 @@ +""" +An example of how to launch the training script on Beaker. +Run this with: + + python src/examples/train_with_mixture_launch.py run_name [OVERRIDES...] +""" + +import sys +from typing import List + +from olmo_core.launch.beaker import BeakerLaunchConfig +from olmo_core.utils import generate_uuid, prepare_cli_environment + + +def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig: + return BeakerLaunchConfig( + name=f"olmo-core-test-{generate_uuid()[:8]}", + budget="ai2/oe-training", + cmd=["src/examples/train_with_mixture.py", run_name, *overrides], + task_name="train", + workspace="ai2/OLMo-core", + description="Testing OLMo-core launch utilities", + clusters=["ai2/allennlp-elanding-a100-40g"], + num_nodes=1, + num_gpus=4, + shared_filesystem=True, + nfs=True, + allow_dirty=True, + ) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]") + sys.exit(1) + + run_name, *overrides = sys.argv[1:] + + prepare_cli_environment() + + build_config(run_name, overrides).launch(follow=True) From 5dff40c6de7788441fb1f9c5f0ef666127e6ffe9 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 15:41:11 -0700 Subject: [PATCH 08/57] temp changes to test --- pyproject.toml | 1 + src/examples/train_with_mixture.py | 4 +++- src/examples/train_with_mixture_launch.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 95ec06a8..5f5c7259 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "omegaconf", "safetensors", "importlib_resources", + "s3fs", # REMOVE THIS IN FAVOR OF SOMETHING CONSISTENT ELSEWHERE ] [project.urls] diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index e30b86ea..a1e1a55e 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -168,7 +168,9 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: "evaluator", LMEvaluatorCallbackConfig( eval_dataset=NumpyDatasetConfig( - paths=["/net/nfs/allennlp/llm-data/c4/en/c4-validation.00000-00008.npy"], + paths=[ + "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" + ], metadata=[{"label": "c4-validation"}], name=NumpyDatasetType.padded_fsl, sequence_length=sequence_length, diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index 4298f25f..d850ee3d 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -24,7 +24,7 @@ def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig: num_nodes=1, num_gpus=4, shared_filesystem=True, - nfs=True, + nfs=False, allow_dirty=True, ) From 2703538c6daeb9238799bca105a7ac884f826a56 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 15:42:53 -0700 Subject: [PATCH 09/57] deps for now --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 5f5c7259..eed6d5bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,8 @@ dependencies = [ "safetensors", "importlib_resources", "s3fs", # REMOVE THIS IN FAVOR OF SOMETHING CONSISTENT ELSEWHERE + "tabulate", + "tqdm", ] [project.urls] From 3ee32787aaeacedf43f10612611d7c8fa113416a Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 15:54:01 -0700 Subject: [PATCH 10/57] Try with session --- src/examples/train_with_mixture.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index a1e1a55e..30ef77d1 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from typing import List, cast, Union +from aiobotocore.session import get_session import s3fs from olmo_core.config import Config, DType @@ -75,7 +76,9 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ], ) - s3 = s3fs.S3FileSystem() + session = get_session() + client = session.create_client("s3", region_name="us-east-1", profile_name="S3") + s3 = s3fs.S3FileSystem(client=client) # DCLM docs baseline = s3.glob( From cd1c6d2b5c5d149568e4d32749c0640c447c1855 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 16:12:55 -0700 Subject: [PATCH 11/57] Try internal client --- src/examples/train_with_mixture.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 30ef77d1..36941b78 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -10,7 +10,6 @@ from dataclasses import dataclass from typing import List, cast, Union -from aiobotocore.session import get_session import s3fs from olmo_core.config import Config, DType @@ -21,6 +20,7 @@ NumpyDatasetType, TokenizerConfig, ) +from olmo_core.io import _get_s3_client from olmo_core.data.types import NumpyDatasetDType from olmo_core.data.source_mixture import SourceMixtureConfig, SourceMixtureDatasetConfig from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType @@ -76,9 +76,8 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ], ) - session = get_session() - client = session.create_client("s3", region_name="us-east-1", profile_name="S3") - s3 = s3fs.S3FileSystem(client=client) + session = _get_s3_client("s3") + s3 = s3fs.S3FileSystem(session=session) # DCLM docs baseline = s3.glob( From 9895b231e3619f7cd88ccd34d10df9e73272820d Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 16:59:27 -0700 Subject: [PATCH 12/57] Try boto3 --- src/examples/train_with_mixture.py | 3 ++- src/olmo_core/launch/beaker.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 36941b78..e742c1ce 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -11,6 +11,7 @@ from typing import List, cast, Union import s3fs +import boto3 from olmo_core.config import Config, DType from olmo_core.data import ( @@ -76,7 +77,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ], ) - session = _get_s3_client("s3") + session = boto3.Session(profile_name="S3_PROFILE") s3 = s3fs.S3FileSystem(session=session) # DCLM docs diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 8ee891cb..df935359 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -213,7 +213,7 @@ def default_env_vars(self) -> List[Tuple[str, str]]: (LOG_FILTER_TYPE_ENV_VAR, LogFilterType.local_rank0_only), ("OMP_NUM_THREADS", "8"), ("R2_PROFILE", "R2"), - ("S3_PROFILE", "S3"), + # ("S3_PROFILE", "S3"), ("WEKA_PROFILE", "WEKA"), ("NUM_NODES", str(self.num_nodes)), ("OLMO_CORE_VERSION", VERSION), From 3c15f52248d6326b5fe6cab8ec59921af620a85d Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 17:03:56 -0700 Subject: [PATCH 13/57] Fixes --- src/examples/train_with_mixture.py | 2 +- src/olmo_core/launch/beaker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index e742c1ce..49caca02 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -77,7 +77,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ], ) - session = boto3.Session(profile_name="S3_PROFILE") + session = boto3.Session(profile_name="S3") s3 = s3fs.S3FileSystem(session=session) # DCLM docs diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index df935359..8ee891cb 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -213,7 +213,7 @@ def default_env_vars(self) -> List[Tuple[str, str]]: (LOG_FILTER_TYPE_ENV_VAR, LogFilterType.local_rank0_only), ("OMP_NUM_THREADS", "8"), ("R2_PROFILE", "R2"), - # ("S3_PROFILE", "S3"), + ("S3_PROFILE", "S3"), ("WEKA_PROFILE", "WEKA"), ("NUM_NODES", str(self.num_nodes)), ("OLMO_CORE_VERSION", VERSION), From 0c9355b8010165b3e51c462dcca71ac1e1b3cd41 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 23 Oct 2024 17:05:17 -0700 Subject: [PATCH 14/57] ? --- src/olmo_core/launch/beaker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 8ee891cb..df935359 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -213,7 +213,7 @@ def default_env_vars(self) -> List[Tuple[str, str]]: (LOG_FILTER_TYPE_ENV_VAR, LogFilterType.local_rank0_only), ("OMP_NUM_THREADS", "8"), ("R2_PROFILE", "R2"), - ("S3_PROFILE", "S3"), + # ("S3_PROFILE", "S3"), ("WEKA_PROFILE", "WEKA"), ("NUM_NODES", str(self.num_nodes)), ("OLMO_CORE_VERSION", VERSION), From abb362a5263c0200cc8244c7e48dcdce2afa8ac2 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 10:29:13 -0700 Subject: [PATCH 15/57] Cleanup + session stuff --- src/examples/train_with_mixture.py | 9 +++++---- src/olmo_core/data/numpy_dataset.py | 18 +++++++++--------- src/olmo_core/launch/beaker.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 49caca02..ac25fb62 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -77,10 +77,11 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ], ) - session = boto3.Session(profile_name="S3") + # TODO: Maybe move the globbing into SourceMixtureConfig? + session = _get_s3_client("s3") s3 = s3fs.S3FileSystem(session=session) - # DCLM docs + # DCLM docs + rewrites baseline = s3.glob( "s3://ai2-llm/preprocessed/dclm/samples/src-100b/**/allenai/dolma2-tokenizer/*.npy" ) @@ -97,12 +98,12 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: paths=[f"s3://{path}" for path in baseline], source_name="baseline", max_repetition_ratio=1.0, - target_ratio=0.7, + target_ratio=0.8, ), SourceMixtureConfig( source_name="rewrites", paths=[f"s3://{path}" for path in rewrites], - target_ratio=0.3, + target_ratio=0.2, ), ], dtype=NumpyDatasetDType.uint32, diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 5f6c6ae3..5318e098 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -1643,7 +1643,14 @@ def __init__( self._path_offset_index = path_offset_index self._bust_index_cache = bust_index_cache - # TODO: overload __getitem__ to read the stuff we need, maybe just with read_chunk_from_array + # TODO: overload __getitem__ to read the stuff we need, maybe just with read_chunk_from_array?? + + def prepare(self): + if self.fs_local_rank == 0: + log.info("Gathering indices...") + self._write_document_indices() + barrier() + len(self) def _get_indices_path(self, path: PathOrStr) -> Path: sha256_hash = hashlib.sha256() @@ -1696,13 +1703,6 @@ def _write_document_indices(self): f"{self.sequence_length} from '{path}'" ) - def prepare(self): - if self.fs_local_rank == 0: - log.info("Gathering indices...") - self._write_document_indices() - barrier() - len(self) - def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: start_idx = index * self.sequence_length return load_array_slice_into_tensor( @@ -1743,7 +1743,7 @@ class NumpyFSLDatasetMixtureConfig(Config): """ A config class for easily building :class:`NumpyFSLDatasetMixture` class. This is a special case of :class:`NumpyFSLDataset` that is built from a mixture of source - datasets based on a mixture configuration. + datasets based on a source mixture configuration. """ source_mixture_config: SourceMixtureDatasetConfig diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index df935359..8ee891cb 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -213,7 +213,7 @@ def default_env_vars(self) -> List[Tuple[str, str]]: (LOG_FILTER_TYPE_ENV_VAR, LogFilterType.local_rank0_only), ("OMP_NUM_THREADS", "8"), ("R2_PROFILE", "R2"), - # ("S3_PROFILE", "S3"), + ("S3_PROFILE", "S3"), ("WEKA_PROFILE", "WEKA"), ("NUM_NODES", str(self.num_nodes)), ("OLMO_CORE_VERSION", VERSION), From 82a1af9601a02ca0edd9b646501f45267ecf76c7 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 10:39:40 -0700 Subject: [PATCH 16/57] Use environ --- src/examples/train_with_mixture.py | 6 +++++- src/examples/train_with_mixture_launch.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index ac25fb62..d96292e2 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -6,6 +6,7 @@ torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name [OVERRIDES...] """ +import os import sys from dataclasses import dataclass from typing import List, cast, Union @@ -78,7 +79,10 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ) # TODO: Maybe move the globbing into SourceMixtureConfig? - session = _get_s3_client("s3") + session = boto3.Session( + aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), + ) s3 = s3fs.S3FileSystem(session=session) # DCLM docs + rewrites diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index d850ee3d..ceaf0b32 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -8,7 +8,7 @@ import sys from typing import List -from olmo_core.launch.beaker import BeakerLaunchConfig +from olmo_core.launch.beaker import BeakerLaunchConfig, BeakerEnvSecret from olmo_core.utils import generate_uuid, prepare_cli_environment @@ -21,6 +21,10 @@ def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig: workspace="ai2/OLMo-core", description="Testing OLMo-core launch utilities", clusters=["ai2/allennlp-elanding-a100-40g"], + env_secrets=[ + BeakerEnvSecret("AWS_CREDENTIALS", "AWS_CREDENTIALS"), + BeakerEnvSecret("AWS_CONFIG", "AWS_CONFIG"), + ], num_nodes=1, num_gpus=4, shared_filesystem=True, From d0a80ba0aa41fd092cdcc31c76cb3ee2e09d7b74 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 10:46:45 -0700 Subject: [PATCH 17/57] JUST use env vars please boto --- src/examples/train_with_mixture.py | 7 +------ src/examples/train_with_mixture_launch.py | 4 ++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index d96292e2..d149073d 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -78,12 +78,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ], ) - # TODO: Maybe move the globbing into SourceMixtureConfig? - session = boto3.Session( - aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), - aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), - ) - s3 = s3fs.S3FileSystem(session=session) + s3 = s3fs.S3FileSystem() # DCLM docs + rewrites baseline = s3.glob( diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index ceaf0b32..3c717e38 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -22,8 +22,8 @@ def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig: description="Testing OLMo-core launch utilities", clusters=["ai2/allennlp-elanding-a100-40g"], env_secrets=[ - BeakerEnvSecret("AWS_CREDENTIALS", "AWS_CREDENTIALS"), - BeakerEnvSecret("AWS_CONFIG", "AWS_CONFIG"), + BeakerEnvSecret("AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY_ID"), + BeakerEnvSecret("AWS_SECRET_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"), ], num_nodes=1, num_gpus=4, From e621f8e0447c3f61d2e571f7807cb2665c8e71df Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 10:48:48 -0700 Subject: [PATCH 18/57] No unions of containers --- src/examples/train_with_mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index d149073d..9845a068 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -54,7 +54,7 @@ class ExperimentConfig(Config): model: TransformerConfig optim: AdamWConfig - dataset: Union[NumpyDatasetConfig, NumpyFSLDatasetMixtureConfig] + dataset: NumpyFSLDatasetMixtureConfig data_loader: NumpyDataLoaderConfig trainer: TrainerConfig init_seed: int = 12536 From 0689c424338f70e050d82c65be773da134ab22ff Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:02:37 -0700 Subject: [PATCH 19/57] prepare first --- src/examples/train_with_mixture.py | 1 + src/olmo_core/data/numpy_dataset.py | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 9845a068..0f0d924d 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -209,6 +209,7 @@ def main(run_name: str, overrides: List[str]): ) optim = config.optim.build(model) dataset = config.dataset.build() + dataset.prepare() data_loader = config.data_loader.build(dataset) trainer = config.trainer.build(model, optim, data_loader) diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 5318e098..e5f98dd5 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -1703,12 +1703,6 @@ def _write_document_indices(self): f"{self.sequence_length} from '{path}'" ) - def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: - start_idx = index * self.sequence_length - return load_array_slice_into_tensor( - path, start_idx, start_idx + self.sequence_length, self.dtype - ) - def _get_file_size_and_length( self, path: PathOrStr, dtype: Optional[SupportedDType] = None ) -> Tuple[int, int]: From 8ab2e99430e16fbc394bdfdb2c0c0c76e64a2ada Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:03:39 -0700 Subject: [PATCH 20/57] Loader handles prepare --- src/examples/train_with_mixture.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 0f0d924d..9845a068 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -209,7 +209,6 @@ def main(run_name: str, overrides: List[str]): ) optim = config.optim.build(model) dataset = config.dataset.build() - dataset.prepare() data_loader = config.data_loader.build(dataset) trainer = config.trainer.build(model, optim, data_loader) From dcfda675ec5a05df5228130d6d11b1c6c1121b4d Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:13:52 -0700 Subject: [PATCH 21/57] Try recording torch exceptions --- src/examples/train_with_mixture.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 9845a068..44154a4f 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -6,13 +6,12 @@ torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name [OVERRIDES...] """ -import os import sys from dataclasses import dataclass -from typing import List, cast, Union +from typing import List, cast import s3fs -import boto3 +from torch.distributed.elastic.multiprocessing.errors import record from olmo_core.config import Config, DType from olmo_core.data import ( @@ -22,7 +21,6 @@ NumpyDatasetType, TokenizerConfig, ) -from olmo_core.io import _get_s3_client from olmo_core.data.types import NumpyDatasetDType from olmo_core.data.source_mixture import SourceMixtureConfig, SourceMixtureDatasetConfig from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType @@ -172,7 +170,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: LMEvaluatorCallbackConfig( eval_dataset=NumpyDatasetConfig( paths=[ - "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" + # "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" ], metadata=[{"label": "c4-validation"}], name=NumpyDatasetType.padded_fsl, @@ -195,6 +193,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: ).merge(overrides) +@record def main(run_name: str, overrides: List[str]): config = build_config(run_name, overrides) From 23a0806a9d8fe0b11830f04aaf2bc5a152f95027 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:18:22 -0700 Subject: [PATCH 22/57] Don't need overrides --- src/examples/train_with_mixture.py | 18 +++++++++--------- src/examples/train_with_mixture_launch.py | 9 ++++----- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 44154a4f..38e274bc 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -3,12 +3,12 @@ Launch this with torchrun: - torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name [OVERRIDES...] + torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name """ import sys from dataclasses import dataclass -from typing import List, cast +from typing import cast import s3fs from torch.distributed.elastic.multiprocessing.errors import record @@ -58,7 +58,7 @@ class ExperimentConfig(Config): init_seed: int = 12536 -def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: +def build_config(run_name: str) -> ExperimentConfig: tokenizer_config = TokenizerConfig.gpt2() model_config = TransformerConfig.llama2_271M( @@ -190,12 +190,12 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: dataset=dataset_config, data_loader=data_loader_config, trainer=trainer_config, - ).merge(overrides) + ) @record -def main(run_name: str, overrides: List[str]): - config = build_config(run_name, overrides) +def main(run_name: str): + config = build_config(run_name) # Set RNG states on all devices. seed_all(config.init_seed) @@ -223,13 +223,13 @@ def main(run_name: str, overrides: List[str]): if __name__ == "__main__": if len(sys.argv) < 2: - print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]") + print(f"Usage: python {sys.argv[0]} run_name") sys.exit(1) - run_name, *overrides = sys.argv[1:] + run_name = sys.argv[1] prepare_training_environment() try: - main(run_name, overrides=overrides) + main(run_name) finally: teardown_training_environment() diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index 3c717e38..1b177458 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -6,17 +6,16 @@ """ import sys -from typing import List from olmo_core.launch.beaker import BeakerLaunchConfig, BeakerEnvSecret from olmo_core.utils import generate_uuid, prepare_cli_environment -def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig: +def build_config(run_name: str) -> BeakerLaunchConfig: return BeakerLaunchConfig( name=f"olmo-core-test-{generate_uuid()[:8]}", budget="ai2/oe-training", - cmd=["src/examples/train_with_mixture.py", run_name, *overrides], + cmd=["src/examples/train_with_mixture.py", run_name], task_name="train", workspace="ai2/OLMo-core", description="Testing OLMo-core launch utilities", @@ -38,8 +37,8 @@ def build_config(run_name: str, overrides: List[str]) -> BeakerLaunchConfig: print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]") sys.exit(1) - run_name, *overrides = sys.argv[1:] + run_name = sys.argv[1] prepare_cli_environment() - build_config(run_name, overrides).launch(follow=True) + build_config(run_name).launch(follow=True) From 8cfa282352df4cba5622f5a43a417feb07888b79 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:31:45 -0700 Subject: [PATCH 23/57] Figure out why config/creds are missing --- src/examples/train_with_mixture.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 38e274bc..40134889 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -12,6 +12,7 @@ import s3fs from torch.distributed.elastic.multiprocessing.errors import record +from beaker import Beaker from olmo_core.config import Config, DType from olmo_core.data import ( @@ -195,6 +196,8 @@ def build_config(run_name: str) -> ExperimentConfig: @record def main(run_name: str): + beaker_user = (Beaker.from_env().account.whoami().name).upper() + print(f"Running as: {beaker_user}") config = build_config(run_name) # Set RNG states on all devices. From fd1a5083cd460d7c2e018a9844337d4056da6977 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:36:59 -0700 Subject: [PATCH 24/57] fmt --- src/olmo_core/internal/experiment.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 42d0ca1e..f649f5d2 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -180,9 +180,11 @@ def build_common_components( vsl_curriculum=VSLCurriculumConfig( name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False ), - work_dir=None - if is_url(root_dir) - else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache", + work_dir=( + None + if is_url(root_dir) + else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache" + ), ) data_loader_config = NumpyDataLoaderConfig( @@ -202,9 +204,11 @@ def build_common_components( mix_base_dir=root_dir, sequence_length=dataset_config.effective_sequence_length, tokenizer=tokenizer_config, - work_dir=None - if is_url(root_dir) - else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache", + work_dir=( + None + if is_url(root_dir) + else f"{root_dir}/checkpoints/{beaker_user.lower()}/dataset-cache" + ), ), eval_interval=1000, ), From 01a40ea92dbcf9145f02a752d5c34b62d1b76966 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:42:05 -0700 Subject: [PATCH 25/57] Env not ready yet --- src/examples/train_with_mixture.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 40134889..38e274bc 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -12,7 +12,6 @@ import s3fs from torch.distributed.elastic.multiprocessing.errors import record -from beaker import Beaker from olmo_core.config import Config, DType from olmo_core.data import ( @@ -196,8 +195,6 @@ def build_config(run_name: str) -> ExperimentConfig: @record def main(run_name: str): - beaker_user = (Beaker.from_env().account.whoami().name).upper() - print(f"Running as: {beaker_user}") config = build_config(run_name) # Set RNG states on all devices. From 8bde2b3b0ac46f64380bf2693a18aa7d2aae13ef Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:48:00 -0700 Subject: [PATCH 26/57] print beaker user --- src/examples/train_with_mixture.py | 2 +- src/olmo_core/internal/experiment.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 38e274bc..894abcd7 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -59,7 +59,7 @@ class ExperimentConfig(Config): def build_config(run_name: str) -> ExperimentConfig: - tokenizer_config = TokenizerConfig.gpt2() + tokenizer_config = TokenizerConfig.dolma2() model_config = TransformerConfig.llama2_271M( vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128 diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index f649f5d2..da9cc961 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -125,6 +125,7 @@ def build_common_components( weka_buckets.append(BeakerWekaBucket("oe-training-default", "/weka/oe-training-default")) beaker_user = (Beaker.from_env().account.whoami().name).upper() + print(f"Beaker user: {beaker_user}") cmd_to_launch = SubCmd.train if cmd == SubCmd.launch_prep: cmd_to_launch = SubCmd.prep From ae208f6ec14b86961294deff6cf1406a31bc5c6b Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 11:57:56 -0700 Subject: [PATCH 27/57] uncomment eval file --- src/examples/train_with_mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 894abcd7..2dc9dd07 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -170,7 +170,7 @@ def build_config(run_name: str) -> ExperimentConfig: LMEvaluatorCallbackConfig( eval_dataset=NumpyDatasetConfig( paths=[ - # "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" + "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" ], metadata=[{"label": "c4-validation"}], name=NumpyDatasetType.padded_fsl, From ce9d06f451598f71141812e0af776f2aabcd4678 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 12:37:35 -0700 Subject: [PATCH 28/57] replicate CommonComponents setup --- src/examples/train_with_mixture_launch.py | 26 +++++++++++++++++++++-- src/olmo_core/internal/experiment.py | 1 - 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index 1b177458..29bdd2af 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -7,11 +7,14 @@ import sys +from beaker import Beaker + from olmo_core.launch.beaker import BeakerLaunchConfig, BeakerEnvSecret from olmo_core.utils import generate_uuid, prepare_cli_environment def build_config(run_name: str) -> BeakerLaunchConfig: + beaker_user = (Beaker.from_env().account.whoami().name).upper() return BeakerLaunchConfig( name=f"olmo-core-test-{generate_uuid()[:8]}", budget="ai2/oe-training", @@ -21,8 +24,27 @@ def build_config(run_name: str) -> BeakerLaunchConfig: description="Testing OLMo-core launch utilities", clusters=["ai2/allennlp-elanding-a100-40g"], env_secrets=[ - BeakerEnvSecret("AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY_ID"), - BeakerEnvSecret("AWS_SECRET_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"), + BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"), + BeakerEnvSecret(name="WANDB_API_KEY", secret=f"{beaker_user}_WANDB_API_KEY"), + BeakerEnvSecret(name="COMET_API_KEY", secret=f"{beaker_user}_COMET_API_KEY"), + BeakerEnvSecret(name="AWS_CONFIG", secret=f"{beaker_user}_AWS_CONFIG"), + BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"), + BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"), + BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"), + ], + setup_steps=[ + # Clone repo. + 'git clone "$REPO_URL" .', + 'git checkout "$GIT_REF"', + "git submodule update --init --recursive", + # Setup python environment. + "conda shell.bash activate base", + "pip install -e '.[all]'", + "pip freeze", + # Move AWS credentials from env to relevant files + "mkdir -p ~/.aws", + "printenv AWS_CONFIG > ~/.aws/config", + "printenv AWS_CREDENTIALS > ~/.aws/credentials", ], num_nodes=1, num_gpus=4, diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index da9cc961..f649f5d2 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -125,7 +125,6 @@ def build_common_components( weka_buckets.append(BeakerWekaBucket("oe-training-default", "/weka/oe-training-default")) beaker_user = (Beaker.from_env().account.whoami().name).upper() - print(f"Beaker user: {beaker_user}") cmd_to_launch = SubCmd.train if cmd == SubCmd.launch_prep: cmd_to_launch = SubCmd.prep From d1eb4dff7aa25393c57879e68accb43d87698678 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 12:55:07 -0700 Subject: [PATCH 29/57] Some class init stuff --- src/olmo_core/data/numpy_dataset.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index e5f98dd5..8e000f7f 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -1622,6 +1622,25 @@ def __init__( max_target_sequence_length: Optional[int] = None, bust_index_cache: bool = False, ): + if max_target_sequence_length is not None and ( + max_target_sequence_length < sequence_length + or max_target_sequence_length % sequence_length != 0 + ): + raise OLMoConfigurationError( + "'max_target_sequence_length' should be a multiple of 'sequence_length'" + ) + + if include_instance_metadata is None and metadata: + include_instance_metadata = True + + if isinstance(metadata, list): + if len(metadata) != len(paths): + raise OLMoConfigurationError( + "'metadata' should have the same length as the number of file paths" + ) + else: + metadata = [metadata or {}] * len(paths) + super().__init__( *paths, pad_token_id=pad_token_id, @@ -1764,7 +1783,7 @@ class NumpyFSLDatasetMixtureConfig(Config): """ Metadata for the numpy arrays. """ - include_instance_metadata: bool = True + include_instance_metadata: bool = False """ Whether or not to include the :data:`metadata` in the instances returned from :meth:`NumpyDatasetBase.__getitem__()`. From dbce279d2441e984b3e93841e0c015783734a0d4 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 13:17:54 -0700 Subject: [PATCH 30/57] Some more config logging --- CHANGELOG.md | 10 ++++++++++ src/olmo_core/data/numpy_dataset.py | 2 -- src/olmo_core/data/source_mixture.py | 6 ++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5383ca29..d867e238 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets. +- Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets. +- Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`. +- Added example launch script for training a model using a `NumpyFSLDatasetMixture`. + +### Changed +- Moved some types into `olmo_core.data.types` to avoid some circular dependencies. + +### Added + - Added `CometCallback` for logging training runs to Comet.ml. - Added `DataMixBase` class, to allow extending to new data mix groups. - Added method `DataLoaderBase.get_mock_batch()`. diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 8e000f7f..dd42d926 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -1662,8 +1662,6 @@ def __init__( self._path_offset_index = path_offset_index self._bust_index_cache = bust_index_cache - # TODO: overload __getitem__ to read the stuff we need, maybe just with read_chunk_from_array?? - def prepare(self): if self.fs_local_rank == 0: log.info("Gathering indices...") diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index cfb4d7b0..332643ec 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -5,6 +5,7 @@ from concurrent.futures import as_completed, ThreadPoolExecutor from dataclasses import dataclass from typing import Dict, List, Optional +from pprint import pprint import tabulate from tqdm import tqdm @@ -142,6 +143,11 @@ def build(self) -> SourceMixtureDataset: random.seed(self.seed) available_tokens_by_source: Dict[str, int] = {} + print("--------------------------------------------------------------------------------") + print("Generating a source mixture from configurations:") + for source_config in self.source_configs: + pprint(source_config) + # Count the number of tokens available for each source for source_config in self.source_configs: log.info(f"Counting tokens for source: {source_config.source_name}") From 980e05a1a441ddc81210135b5a099fa73f63ca72 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Thu, 24 Oct 2024 14:49:55 -0700 Subject: [PATCH 31/57] checks cleanup --- src/examples/train_with_mixture.py | 7 +- src/examples/train_with_mixture_launch.py | 2 +- src/olmo_core/data/__init__.py | 2 +- src/olmo_core/data/numpy_dataset.py | 6 +- src/olmo_core/data/source_mixture.py | 12 +-- src/olmo_core/data/types.py | 1 - src/olmo_core/data/utils.py | 2 +- src/test/data/numpy_dataset_test.py | 18 +++-- src/test/data/source_mixture_test.py | 89 +++++++---------------- 9 files changed, 56 insertions(+), 83 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 2dc9dd07..9a52fe27 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -17,12 +17,15 @@ from olmo_core.data import ( NumpyDataLoaderConfig, NumpyDatasetConfig, - NumpyFSLDatasetMixtureConfig, NumpyDatasetType, + NumpyFSLDatasetMixtureConfig, TokenizerConfig, ) +from olmo_core.data.source_mixture import ( + SourceMixtureConfig, + SourceMixtureDatasetConfig, +) from olmo_core.data.types import NumpyDatasetDType -from olmo_core.data.source_mixture import SourceMixtureConfig, SourceMixtureDatasetConfig from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType from olmo_core.distributed.utils import init_hybrid_shard_mesh from olmo_core.nn.transformer import TransformerConfig diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index 29bdd2af..705d21bd 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -9,7 +9,7 @@ from beaker import Beaker -from olmo_core.launch.beaker import BeakerLaunchConfig, BeakerEnvSecret +from olmo_core.launch.beaker import BeakerEnvSecret, BeakerLaunchConfig from olmo_core.utils import generate_uuid, prepare_cli_environment diff --git a/src/olmo_core/data/__init__.py b/src/olmo_core/data/__init__.py index 3ddd0d23..07b4bfa5 100644 --- a/src/olmo_core/data/__init__.py +++ b/src/olmo_core/data/__init__.py @@ -24,8 +24,8 @@ from .numpy_dataset import ( NumpyDatasetBase, NumpyDatasetConfig, - NumpyFSLDatasetMixtureConfig, NumpyFSLDataset, + NumpyFSLDatasetMixtureConfig, NumpyPaddedFSLDataset, NumpyVSLDataset, VSLCurriculum, diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index dd42d926..45b27eae 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -29,9 +29,9 @@ import torch.nn.functional as F from torch.utils.data import Dataset +from olmo_core.data.source_mixture import SourceMixtureDatasetConfig +from olmo_core.data.types import NumpyDatasetDType, NumpyDatasetType, SupportedDType from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError -from olmo_core.data.types import NumpyDatasetType, NumpyDatasetDType, SupportedDType -from olmo_core.data.source_mixture import SourceMixtureDatasetConfig, SourcePathTokens from ..aliases import PathOrStr from ..config import Config, StrEnum @@ -1653,7 +1653,7 @@ def __init__( generate_doc_lengths=generate_doc_lengths, max_target_sequence_length=max_target_sequence_length, ) - self._metadata = metadata + self._metadata = tuple(metadata) self._include_instance_metadata = include_instance_metadata self._num_instances: Optional[int] = None self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index 332643ec..bc4096fa 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -1,11 +1,11 @@ import logging import math import random -from itertools import chain -from concurrent.futures import as_completed, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass -from typing import Dict, List, Optional +from itertools import chain from pprint import pprint +from typing import Dict, List, Optional import tabulate from tqdm import tqdm @@ -220,9 +220,9 @@ def build(self) -> SourceMixtureDataset: ), ) - for source in completed: - for item in source.path_tokens: - log.info(f"Selected {item.tokens} tokens from {source.name} at {item.path}") + for outcome in completed: + for item in outcome.path_tokens: + log.info(f"Selected {item.tokens} tokens from {outcome.name} at {item.path}") return SourceMixtureDataset(completed) diff --git a/src/olmo_core/data/types.py b/src/olmo_core/data/types.py index f6cffb2b..2b814b28 100644 --- a/src/olmo_core/data/types.py +++ b/src/olmo_core/data/types.py @@ -4,7 +4,6 @@ from olmo_core.config import StrEnum - SupportedDType = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index 4834e1e5..7bc29ad0 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -420,7 +420,7 @@ def segment_documents_into_instances( Returns the number of original documents and the number of resulting instances documents. """ total_og_docs = 0 - indices = [] + indices: List[int] = [] for start_idx, end_idx in iter_document_indices(path, eos_token_id=eos_token_id, dtype=dtype): if max_instances and len(indices) // 2 >= max_instances: break diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 2dd28e28..c1a2e3e3 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -1,22 +1,26 @@ +from os import PathLike from pathlib import Path -from typing import List, Tuple +from typing import Any, List, Tuple, Union import numpy as np from olmo_core.data import ( NumpyDatasetConfig, - NumpyFSLDatasetMixtureConfig, NumpyFSLDataset, + NumpyFSLDatasetMixtureConfig, NumpyPaddedFSLDataset, NumpyVSLDataset, TokenizerConfig, ) - -from olmo_core.aliases import PathOrStr +from olmo_core.data.source_mixture import ( + SourceMixtureConfig, + SourceMixtureDatasetConfig, +) from olmo_core.data.types import NumpyDatasetDType -from olmo_core.data.source_mixture import SourceMixtureDatasetConfig, SourceMixtureConfig from olmo_core.data.utils import get_document_indices, write_document_indices +Mmaps = List[Tuple[Union[Path, PathLike[Any], str], Any]] + def _make_mmaps( tmp_path: Path, @@ -27,8 +31,8 @@ def _make_mmaps( eos: int, seq_length: int = 4, seed: int = 42, -) -> List[Tuple[PathOrStr, List]]: - mmaps = [] +) -> Mmaps: + mmaps: Mmaps = [] for i in range(num_files): filepath = f"{tmp_path}/{prefix}_{i}.npy" np.random.seed(seed) diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py index f035bb3c..adbf1af7 100644 --- a/src/test/data/source_mixture_test.py +++ b/src/test/data/source_mixture_test.py @@ -1,33 +1,30 @@ from itertools import chain +from os import PathLike from pathlib import Path -from tempfile import TemporaryDirectory -from typing import List +from typing import Any, List, Union import numpy as np import pytest -from olmo_core.aliases import PathOrStr from olmo_core.data import NumpyDatasetDType from olmo_core.data.source_mixture import ( SourceMixtureConfig, SourceMixtureDataset, SourceMixtureDatasetConfig, ) -from olmo_core.data.utils import load_array_slice from olmo_core.exceptions import OLMoConfigurationError -DATA = { - "dtype": NumpyDatasetDType.uint32, - "tokens_per_file": 1_000_000, -} +Mmaps = List[Union[Path, PathLike[Any], str]] -def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> List[PathOrStr]: - mmaps = [] +def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> Mmaps: + mmaps: Mmaps = [] for i in range(num_files): filepath = f"{tmp_path}/{prefix}_{i}.npy" data = np.random.randint(0, 2**32, size=size, dtype=np.uint32) - mm = np.memmap(filepath, mode="w+", dtype=DATA["dtype"].as_np_dtype(), shape=(size,)) + mm = np.memmap( + filepath, mode="w+", dtype=NumpyDatasetDType.uint32.as_np_dtype(), shape=(size,) + ) mm[:] = data mm.flush() mmaps.append(Path(filepath)) @@ -37,15 +34,9 @@ def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> List[ def test_source_mixture_config(tmp_path: Path, capsys): source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=2, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] - ), + "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), + "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ @@ -132,15 +123,9 @@ def test_dataset_mixture_config_validation(): def test_dataset_mixture_build(tmp_path: Path): source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=2, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] - ), + "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), + "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ @@ -162,7 +147,7 @@ def test_dataset_mixture_build(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=DATA["dtype"], + dtype=NumpyDatasetDType.uint32, sequence_length=1024, ) @@ -172,15 +157,9 @@ def test_dataset_mixture_build(tmp_path: Path): def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] - ), + "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ SourceMixtureConfig( @@ -201,7 +180,7 @@ def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=DATA["dtype"], + dtype=NumpyDatasetDType.uint32, sequence_length=1024, ) @@ -217,15 +196,9 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): Source 1 has a target ratio of 90% and a max repetition ratio of 4.0, so it should be possible to meet the target of 3600 tokens with 1 file of 1000 tokens repeated 4 times. """ source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] - ), + "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ @@ -248,7 +221,7 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=DATA["dtype"], + dtype=NumpyDatasetDType.uint32, sequence_length=1024, ) @@ -258,15 +231,9 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): source_paths = { - "1": _make_mmaps( - tmp_path=tmp_path, prefix="source1", num_files=1, size=DATA["tokens_per_file"] - ), - "2": _make_mmaps( - tmp_path=tmp_path, prefix="source2", num_files=2, size=DATA["tokens_per_file"] - ), - "3": _make_mmaps( - tmp_path=tmp_path, prefix="source3", num_files=2, size=DATA["tokens_per_file"] - ), + "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ SourceMixtureConfig( @@ -288,12 +255,12 @@ def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): ] # 5 source files * 1_000_000 tokens per file - max_tokens = len(list(chain(*source_paths.values()))) * DATA["tokens_per_file"] + max_tokens = len(list(chain(*source_paths.values()))) * 1_000_000 config = SourceMixtureDatasetConfig( max_tokens=max_tokens, source_configs=source_configs, - dtype=DATA["dtype"], + dtype=NumpyDatasetDType.uint32, sequence_length=1024, ) From f27bd734a92e89fec21615af658c6dbfd3405cf7 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 11:07:45 -0700 Subject: [PATCH 32/57] Fixes for duplicate paths in mixture --- src/examples/train_with_mixture.py | 10 ++-- src/olmo_core/data/numpy_dataset.py | 54 ++++++++++++--------- src/olmo_core/data/source_mixture.py | 51 ++++++++++++++++---- src/olmo_core/data/utils.py | 3 +- src/test/data/numpy_dataset_test.py | 71 ++++++++++++++++++++++++++-- src/test/data/source_mixture_test.py | 48 +++++++++++++++++++ 6 files changed, 194 insertions(+), 43 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 9a52fe27..41856757 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -1,5 +1,5 @@ """ -Example of how to train a transformer language model. +Example of how to train a transformer language model with a source mixture config. Launch this with torchrun: @@ -91,13 +91,13 @@ def build_config(run_name: str) -> ExperimentConfig: sequence_length = 1024 source_config = SourceMixtureDatasetConfig( - max_tokens=20_000_000, + max_tokens=int(10e8), # 100M tokens sequence_length=sequence_length, source_configs=[ SourceMixtureConfig( paths=[f"s3://{path}" for path in baseline], source_name="baseline", - max_repetition_ratio=1.0, + max_repetition_ratio=1.0, # 1.0 is a no-op but added here to illustrate the option target_ratio=0.8, ), SourceMixtureConfig( @@ -120,7 +120,7 @@ def build_config(run_name: str) -> ExperimentConfig: ) data_loader_config = NumpyDataLoaderConfig( - global_batch_size=256 * 1024, + global_batch_size=256 * sequence_length, seed=0, num_workers=4, ) @@ -128,7 +128,7 @@ def build_config(run_name: str) -> ExperimentConfig: trainer_config = ( TrainerConfig( save_folder=f"/tmp/{run_name}", - rank_microbatch_size=16 * 1024, + rank_microbatch_size=16 * sequence_length, save_overwrite=True, metrics_collect_interval=5, cancel_check_interval=5, diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 45b27eae..4b0d923d 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -136,7 +136,7 @@ def file_sizes(self) -> Tuple[int, ...]: The size, in bytes, of each numpy array. """ if self._array_file_sizes is None: - self._array_file_sizes = tuple(self.map(get_file_size)) + self._array_file_sizes = tuple(self.map(lambda item: get_file_size(item[0]))) return self._array_file_sizes @property @@ -236,7 +236,7 @@ def _warmup_clients(self): def map( self, - func: Callable[[PathOrStr], T], + func: Callable[[Tuple[PathOrStr, int]], T], *, max_workers: Optional[int] = None, method: Literal["threads", "processes"] = "threads", @@ -255,7 +255,7 @@ def map( paths = _paths or self.paths if max_workers == 0: - return [func(path) for path in paths] + return [func((path, idx)) for idx, path in enumerate(paths)] executor_class: Union[ Type[concurrent.futures.ThreadPoolExecutor], @@ -271,9 +271,9 @@ def map( with executor_class(max_workers=max_workers) as executor: path_to_future = {} - for path in paths: + for idx, path in enumerate(paths): if path not in path_to_future: - path_to_future[path] = executor.submit(func, path) + path_to_future[path] = executor.submit(func, (path, idx)) results = [] for path in paths: @@ -484,7 +484,10 @@ def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: path, start_idx, start_idx + self.sequence_length, self.dtype ) - def _get_file_size_and_length(self, path, dtype=None) -> Tuple[int, int]: + def _get_file_size_and_length( + self, item: Tuple[PathOrStr, int], dtype: Optional[SupportedDType] = None + ) -> Tuple[int, int]: + path, _ = item dtype = dtype or self.dtype item_size = dtype(0).itemsize file_size = get_file_size(path) @@ -538,7 +541,8 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]: if self._array_instance_offsets is None: item_size = self.indices_dtype(0).itemsize num_instances_per_path = self.map( - lambda path: get_file_size(self._get_instance_indices_path(path)) // (item_size * 2) + lambda item: get_file_size(self._get_instance_indices_path(item[0])) + // (item_size * 2) ) array_instance_offsets = [] start_offset = 0 @@ -1610,7 +1614,7 @@ class NumpyFSLDatasetMixture(NumpyFSLDataset): def __init__( self, *paths: PathOrStr, - path_offset_index: Dict[str, int], + path_offset_index: Dict[Tuple[str, int], int], sequence_length: int, pad_token_id: int, eos_token_id: int, @@ -1681,22 +1685,24 @@ def _get_indices_path(self, path: PathOrStr) -> Path: ) def _write_document_indices(self): - paths_needed: List[PathOrStr] = [] - for path in self.paths: + paths_needed: List[Tuple[PathOrStr, int]] = [] + for idx, path in enumerate(self.paths): indices_path = self._get_indices_path(path) if indices_path.is_file() and not self._bust_index_cache: log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'") elif path not in paths_needed: - paths_needed.append(path) + paths_needed.append((path, idx)) if paths_needed: with concurrent.futures.ProcessPoolExecutor() as executor: futures = [] - for path in paths_needed: + for path, idx in paths_needed: indices_path = self._get_indices_path(path) log.info(f"Gathering instance indices for '{path}'...") - # NOTE: We limit the number of instances to the number by total target token count - max_instances = self._path_offset_index[str(path)] // self.sequence_length + # NOTE: We limit the number of instances by total target token count // sequence length + max_instances = ( + self._path_offset_index[(str(path), idx)] // self.sequence_length + ) future = executor.submit( run_worker_func, segment_documents_into_instances, @@ -1713,7 +1719,7 @@ def _write_document_indices(self): concurrent.futures.wait(futures, return_when="ALL_COMPLETED") # Log results. - for path, future in zip(paths_needed, futures): + for path, future in zip([item[0] for item in paths_needed], futures): _, total_instances = future.result() log.info( f"Created {total_instances:,d} instances of sequence length up to " @@ -1721,11 +1727,12 @@ def _write_document_indices(self): ) def _get_file_size_and_length( - self, path: PathOrStr, dtype: Optional[SupportedDType] = None + self, item: Tuple[PathOrStr, int], dtype: Optional[SupportedDType] = None ) -> Tuple[int, int]: + path, idx = item dtype = dtype or self.dtype item_size = dtype(0).itemsize - file_size = self._get_size_from_offset_index(path) + file_size = self._get_size_from_offset_index(item) if ( self.max_target_sequence_length is None or self.max_target_sequence_length == self.sequence_length @@ -1741,12 +1748,13 @@ def _get_file_size_and_length( else: raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") - def _get_size_from_offset_index(self, path: PathOrStr) -> int: + def _get_size_from_offset_index(self, path_index: Tuple[PathOrStr, int]) -> int: try: + path, idx = path_index # Get size in bytes from tokens in the supplied index * itemsize - return self._path_offset_index[str(path)] * self.dtype(0).itemsize + return self._path_offset_index[(str(path), idx)] * self.dtype(0).itemsize except KeyError: - raise OLMoEnvironmentError(f"Path {path} not found in path index") + raise OLMoEnvironmentError(f"Item not found in path index @ {path_index}") @dataclass @@ -1828,9 +1836,9 @@ def build(self) -> NumpyFSLDataset: """ Construct the corresponding :class:`NumpyFSLDatasetMixture`. """ - mixture = self.source_mixture_config.build().to_path_instance_index() + mixture = self.source_mixture_config.build() return NumpyFSLDatasetMixture( - *mixture.keys(), + *mixture.to_paths(), sequence_length=self.sequence_length or 1024, max_target_sequence_length=self.max_target_sequence_length, pad_token_id=self.tokenizer.pad_token_id, @@ -1840,6 +1848,6 @@ def build(self) -> NumpyFSLDataset: metadata=self.metadata, include_instance_metadata=self.include_instance_metadata, generate_doc_lengths=self.generate_doc_lengths, - path_offset_index=mixture, + path_offset_index=mixture.to_index(), bust_index_cache=self.bust_index_cache, ) diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index bc4096fa..2d3d07e0 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from itertools import chain from pprint import pprint -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import tabulate from tqdm import tqdm @@ -107,12 +107,25 @@ class SourceMixtureDataset: sources: List[SourceMixtureOutcome] - def to_path_instance_index(self) -> Dict[str, int]: + def to_index(self) -> Dict[Tuple[str, int], int]: """ - Convert the dataset to a dictionary of paths and instance counts to retain. + Convert the dataset to an indexed array of dict((int, path), int). """ - outcomes = chain.from_iterable([outcome.path_tokens for outcome in self.sources]) - return {str(outcome.path): outcome.tokens for outcome in outcomes} + return { + (str(outcome.path), idx): outcome.tokens + for idx, outcome in enumerate( + list(chain.from_iterable([outcome.path_tokens for outcome in self.sources])) + ) + } + + def to_paths(self) -> List[PathOrStr]: + """ + Convert the dataset to a list of paths while maintaining stable ordering. + """ + return [ + item.path + for item in list(chain.from_iterable([outcome.path_tokens for outcome in self.sources])) + ] @dataclass @@ -143,7 +156,7 @@ def build(self) -> SourceMixtureDataset: random.seed(self.seed) available_tokens_by_source: Dict[str, int] = {} - print("--------------------------------------------------------------------------------") + print("---------------------------------------------------------") print("Generating a source mixture from configurations:") for source_config in self.source_configs: pprint(source_config) @@ -157,7 +170,7 @@ def build(self) -> SourceMixtureDataset: tokens_details_by_source: List[SourceTokenDetails] = [] - # Calculate the number of tokens to include for each source + # Calculate the number of tokens available and to include for each source for source_config in self.source_configs: num_for_source = available_tokens_by_source[source_config.source_name] needed_for_source = int(self.max_tokens * source_config.target_ratio) @@ -187,7 +200,7 @@ def build(self) -> SourceMixtureDataset: name=source.config.source_name, path_tokens=self.get_paths_and_tokens_for_source( source_config=source.config, - take_ratio=source.num_selected / source.population, + token_details=source, ), ) ) @@ -227,13 +240,31 @@ def build(self) -> SourceMixtureDataset: return SourceMixtureDataset(completed) def get_paths_and_tokens_for_source( - self, source_config: SourceMixtureConfig, take_ratio: float + self, source_config: SourceMixtureConfig, token_details: SourceTokenDetails ) -> List[SourcePathTokens]: """ Get the paths and resulting token count for a source. """ - # TODO: Handle repetition ratio by adding paths multiple times, max_repetition_ratio + take_ratio = token_details.num_selected / token_details.population path_tokens = [] + + # When we need more than 1 repetition of the source data we have a take ration > 1 + if take_ratio > 1: + take_ratios = [] + remaining = take_ratio + + while remaining > 0: + chunk = min(1.0, remaining) + take_ratios.append(chunk) + remaining -= chunk + + for ratio in take_ratios: + for path in source_config.paths: + tokens_to_keep = int(math.ceil(self._count_tokens_for_file(path) * ratio)) + path_tokens.append(SourcePathTokens(path=path, tokens=tokens_to_keep)) + + return path_tokens + for path in source_config.paths: tokens_to_keep = int(math.ceil(self._count_tokens_for_file(path) * take_ratio)) path_tokens.append(SourcePathTokens(path=path, tokens=tokens_to_keep)) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index 7bc29ad0..d29278d2 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -1,6 +1,7 @@ import gzip import math import os +import random from contextlib import contextmanager from pathlib import Path from typing import ( @@ -328,7 +329,7 @@ def memmap_to_write( file until the context exists successfully. """ path.parent.mkdir(exist_ok=True, parents=True) - tmp_path = path.with_suffix(".npy.tmp") + tmp_path = path.with_suffix(f".{random.randint(0,2**16)}.npy.tmp") mmap = np.memmap(tmp_path, dtype=dtype, mode="w+", shape=shape) try: yield mmap diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index c1a2e3e3..6981ebc0 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -101,8 +101,7 @@ def test_numpy_padded_fsl_dataset(tmp_path: Path): def test_numpy_fsl_mixture_dataset(tmp_path: Path): - # NOTE: At very small token counts (10's of tokens) - # the take_ratio gets finicky so we test at small but real world-ish scale) + # NOTE: At small token counts the take_ratio can be finicky so we test at small but real world-ish scale npdtype = np.uint16 seed = 42 mmap1 = _make_mmaps(tmp_path, "mmap1", 1, 20 * 1000, npdtype, eos=0, seed=seed) @@ -147,8 +146,72 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): expected = "68144f" assert ds.fingerprint.endswith( expected - ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update the expected fingerprint?" - assert ds[0]["input_ids"].tolist() == [56422, 24545, 15795, 52202] + ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update expected fingerprint?" + assert ds[0]["input_ids"].tolist() == [ + 56422, + 24545, + 15795, + 52202, + ] # stable because we pass a seed + assert ds.num_tokens == 10000 + assert len(ds) == 2500 + + +def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): + # NOTE: At small token counts the take_ratio can be finicky so we test at small but real world-ish scale + npdtype = np.uint16 + seed = 42 + mmap1 = _make_mmaps(tmp_path, "mmap1", 1, 10 * 1000, npdtype, eos=0, seed=seed) + mmap2 = _make_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) + + sequence_length = 4 + tokenizer = TokenizerConfig( + vocab_size=32_000, + eos_token_id=0, + pad_token_id=-1, + ) + + source1_paths = [i[0] for i in mmap1] * 2 # duplicate the paths + + mixture_config = SourceMixtureDatasetConfig( + max_tokens=10_000, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + source_name="mmap1", + paths=source1_paths, + target_ratio=0.8, + ), + SourceMixtureConfig( + source_name="mmap2", + paths=[i[0] for i in mmap2], + target_ratio=0.2, + ), + ], + dtype=NumpyDatasetDType.uint16, + processes=1, + seed=seed, + ) + + ds = NumpyFSLDatasetMixtureConfig( + source_mixture_config=mixture_config, + sequence_length=sequence_length, + tokenizer=tokenizer, + bust_index_cache=True, + include_instance_metadata=False, + ).build() + ds.prepare() + + expected = "190cd0" + assert ds.fingerprint.endswith( + expected + ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update expected fingerprint?" + assert ds[0]["input_ids"].tolist() == [ + 56422, + 24545, + 15795, + 52202, + ] # stable because we pass a seed assert ds.num_tokens == 10000 assert len(ds) == 2500 diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py index adbf1af7..cfd1305e 100644 --- a/src/test/data/source_mixture_test.py +++ b/src/test/data/source_mixture_test.py @@ -226,7 +226,14 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): ) mixture = config.build() + sources = [source for source in mixture.sources] + all_paths = [] + for source in sources: + all_paths.extend([item for item in source.path_tokens]) + + total_tokens = sum([item.tokens for item in all_paths]) assert isinstance(mixture, SourceMixtureDataset) + assert total_tokens == 5_000_000 def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): @@ -268,3 +275,44 @@ def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): # we limit usage to 10% of the source with pytest.raises(OLMoConfigurationError): config.build() + + +# TODO: Handle duplicate paths in source mixture +def test_dataset_mixture_build_duplicate_paths(tmp_path: Path): + sources = { + "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=500_000), + "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + } + + source_configs = [ + SourceMixtureConfig( + source_name="1", + target_ratio=0.33, # 990k tokens + max_repetition_ratio=2.0, + paths=[sources["1"][0], sources["1"][0]], # Duplicate the 1 path for source 1 + ), + SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=sources["2"]), + SourceMixtureConfig( + source_name="3", + target_ratio=0.34, + paths=sources["3"], + ), + ] + + max_tokens = 3_000_000 + + config = SourceMixtureDatasetConfig( + max_tokens=max_tokens, + source_configs=source_configs, + dtype=NumpyDatasetDType.uint32, + sequence_length=1024, + ) + + mixture = config.build() + index = mixture.to_index() + paths = mixture.to_paths() + assert paths == [sources["1"][0], sources["1"][0]] + sources["2"] + sources["3"] + assert len(index) == 6 + assert isinstance(mixture, SourceMixtureDataset) + assert len(mixture.sources) == 3 From c69b228a9814e1bd58cfd38d2b6cb71b60006da2 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 11:15:00 -0700 Subject: [PATCH 33/57] In case there a ton of files --- src/olmo_core/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index d29278d2..cc74deef 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -329,7 +329,7 @@ def memmap_to_write( file until the context exists successfully. """ path.parent.mkdir(exist_ok=True, parents=True) - tmp_path = path.with_suffix(f".{random.randint(0,2**16)}.npy.tmp") + tmp_path = path.with_suffix(f".{random.randint(0,2**32)}.npy.tmp") mmap = np.memmap(tmp_path, dtype=dtype, mode="w+", shape=shape) try: yield mmap From 18efafdcd3f59dcc95af9e16c5c70550c3ba98f8 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 11:26:55 -0700 Subject: [PATCH 34/57] Maybe fix trainer launch --- src/examples/train_with_mixture.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 41856757..14d7cce9 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -26,7 +26,12 @@ SourceMixtureDatasetConfig, ) from olmo_core.data.types import NumpyDatasetDType -from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.nn.transformer import ( + TransformerConfig, + TransformerDataParallelConfig, + TransformerDataParallelWrappingStrategy, +) from olmo_core.distributed.utils import init_hybrid_shard_mesh from olmo_core.nn.transformer import TransformerConfig from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride @@ -65,10 +70,14 @@ def build_config(run_name: str) -> ExperimentConfig: tokenizer_config = TokenizerConfig.dolma2() model_config = TransformerConfig.llama2_271M( - vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128 + # a little bigger than actual vocab size to make it a multiple of 128 + vocab_size=tokenizer_config.padded_vocab_size(), compile=True, - dp_config=DataParallelConfig( - name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 + dp_config=TransformerDataParallelConfig( + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + wrapping_strategy=TransformerDataParallelWrappingStrategy.full, ), ) From 5ceed463897d489dcf65e588f712d2497dd7d1a2 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 11:37:59 -0700 Subject: [PATCH 35/57] Match other example --- src/examples/train_with_mixture.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 14d7cce9..8ea7c685 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -77,7 +77,6 @@ def build_config(run_name: str) -> ExperimentConfig: name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32, - wrapping_strategy=TransformerDataParallelWrappingStrategy.full, ), ) From 4c7513e63e4d4967fce854d3e0384121dd9d2f70 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 14:00:20 -0700 Subject: [PATCH 36/57] More tests --- src/test/data/data_loader_test.py | 61 +++++++++++++++ src/test/data/fixtures.py | 71 +++++++++++++++++ src/test/data/numpy_dataset_test.py | 39 ++-------- src/test/data/source_mixture_test.py | 109 +++++++++++++++------------ src/test/utils.py | 34 +++++++++ 5 files changed, 231 insertions(+), 83 deletions(-) create mode 100644 src/test/data/fixtures.py diff --git a/src/test/data/data_loader_test.py b/src/test/data/data_loader_test.py index 3858b605..13629154 100644 --- a/src/test/data/data_loader_test.py +++ b/src/test/data/data_loader_test.py @@ -1,3 +1,4 @@ +from itertools import chain from pathlib import Path from typing import List @@ -15,6 +16,8 @@ VSLNaturalCurriculum, ) +from .fixtures import get_fsl_mixture + @pytest.mark.parametrize( "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size", @@ -184,6 +187,64 @@ def test_fsl_data_loader_multiple_epochs( assert data_loader.tokens_processed == 0 +@pytest.mark.parametrize( + "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size", + [ + (100, 4, 2, 2, 2, 8), # 2 instances per batch, 12 instances total + ], +) +def test_fsl_data_loader_with_mixture( + tmp_path: Path, + num_tokens: int, + sequence_length: int, + world_size: int, + num_workers: int, + num_threads: int, + batch_size: int, # in tokens +): + + dataset = get_fsl_mixture(tmp_path, sequence_length=sequence_length, num_tokens=num_tokens) + assert batch_size % sequence_length == 0 + assert batch_size % world_size == 0 + rank_batch_size = batch_size // world_size + assert rank_batch_size > 0 + num_batches = num_tokens // batch_size + + def get_all_batches() -> List[List[int]]: + all_batches: List[List[int]] = [[] for _ in range(num_batches)] + for rank in range(world_size): + data_loader = NumpyFSLDataLoader( + dataset, + global_batch_size=batch_size, + collator=DataCollator(pad_token_id=-1), + shuffle=False, + num_threads=num_threads, + work_dir=tmp_path, + dp_rank=rank, + dp_world_size=world_size, + num_workers=num_workers, + ) + data_loader.reshuffle(epoch=1) + batches = list(data_loader) + assert len(batches) == num_batches + for i, batch in enumerate(batches): + for instance in batch["input_ids"]: + all_batches[i].extend(instance.tolist()) + return all_batches + + all_batches = get_all_batches() + all_tokens = [] + assert len(all_batches) == num_batches + for batch in all_batches: + assert len(batch) == batch_size + all_tokens.extend(batch) + + ds_tokens = list(chain.from_iterable([i["input_ids"].tolist() for i in dataset])) + + assert len(all_tokens) == num_batches * batch_size + assert set(all_tokens) == set(ds_tokens) + + @pytest.mark.parametrize( "shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no-shuffle")] ) diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py new file mode 100644 index 00000000..6a48bc87 --- /dev/null +++ b/src/test/data/fixtures.py @@ -0,0 +1,71 @@ +from pathlib import Path +from typing import Type, Union + +import numpy as np + +from olmo_core.data import ( + NumpyFSLDataset, + NumpyFSLDatasetMixtureConfig, + TokenizerConfig, +) +from olmo_core.data.types import NumpyDatasetDType +from olmo_core.data.source_mixture import ( + SourceMixtureConfig, + SourceMixtureDatasetConfig, +) + +from ..utils import mk_mmaps + + +def get_fsl_mixture( + tmp_path: Path, + dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint32, + seed: int = 42, + sequence_length: int = 4, + num_tokens: int = 20 * 1000, + eos: int = 0, +) -> NumpyFSLDataset: + seed = 42 + mmap1 = mk_mmaps( + tmp_path, "mmap1", 1, num_tokens // 2, dtype, eos=eos, seed=seed, seq_length=sequence_length + ) + mmap2 = mk_mmaps( + tmp_path, "mmap2", 1, num_tokens // 2, dtype, eos=eos, seed=seed, seq_length=sequence_length + ) + + tokenizer = TokenizerConfig( + vocab_size=32_000, + eos_token_id=eos, + pad_token_id=-1, + ) + + mixture_config = SourceMixtureDatasetConfig( + max_tokens=num_tokens, + sequence_length=sequence_length, + source_configs=[ + SourceMixtureConfig( + source_name="mmap1", + paths=[i[0] for i in mmap1], + target_ratio=0.8, + ), + SourceMixtureConfig( + source_name="mmap2", + paths=[i[0] for i in mmap2], + target_ratio=0.2, + ), + ], + dtype=NumpyDatasetDType.uint16, + processes=1, + seed=seed, + ) + + ds = NumpyFSLDatasetMixtureConfig( + source_mixture_config=mixture_config, + sequence_length=sequence_length, + tokenizer=tokenizer, + bust_index_cache=True, + include_instance_metadata=False, + ).build() + ds.prepare() + + return ds diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 6981ebc0..11b62d4b 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -1,6 +1,5 @@ -from os import PathLike from pathlib import Path -from typing import Any, List, Tuple, Union +from typing import List import numpy as np @@ -19,33 +18,7 @@ from olmo_core.data.types import NumpyDatasetDType from olmo_core.data.utils import get_document_indices, write_document_indices -Mmaps = List[Tuple[Union[Path, PathLike[Any], str], Any]] - - -def _make_mmaps( - tmp_path: Path, - prefix: str, - num_files: int, - size: int, - dtype, - eos: int, - seq_length: int = 4, - seed: int = 42, -) -> Mmaps: - mmaps: Mmaps = [] - for i in range(num_files): - filepath = f"{tmp_path}/{prefix}_{i}.npy" - np.random.seed(seed) - data = np.random.randint(0, np.iinfo(dtype).max, size=size, dtype=dtype) - data = np.append( - np.insert(data, np.arange(seq_length + 1, len(data), seq_length), eos), eos - ) - mm = np.memmap(filepath, mode="w+", dtype=dtype, shape=(len(data),)) - mm[:] = data - mm.flush() - mmaps.append((Path(filepath), data)) - - return mmaps +from ..utils import mk_mmaps def test_numpy_fsl_dataset(tmp_path: Path): @@ -104,8 +77,8 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): # NOTE: At small token counts the take_ratio can be finicky so we test at small but real world-ish scale npdtype = np.uint16 seed = 42 - mmap1 = _make_mmaps(tmp_path, "mmap1", 1, 20 * 1000, npdtype, eos=0, seed=seed) - mmap2 = _make_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) + mmap1 = mk_mmaps(tmp_path, "mmap1", 1, 20 * 1000, npdtype, eos=0, seed=seed) + mmap2 = mk_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) sequence_length = 4 tokenizer = TokenizerConfig( @@ -161,8 +134,8 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): # NOTE: At small token counts the take_ratio can be finicky so we test at small but real world-ish scale npdtype = np.uint16 seed = 42 - mmap1 = _make_mmaps(tmp_path, "mmap1", 1, 10 * 1000, npdtype, eos=0, seed=seed) - mmap2 = _make_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) + mmap1 = mk_mmaps(tmp_path, "mmap1", 1, 10 * 1000, npdtype, eos=0, seed=seed) + mmap2 = mk_mmaps(tmp_path, "mmap2", 1, 20 * 1000, npdtype, eos=0, seed=seed) sequence_length = 4 tokenizer = TokenizerConfig( diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py index cfd1305e..b640be4d 100644 --- a/src/test/data/source_mixture_test.py +++ b/src/test/data/source_mixture_test.py @@ -14,42 +14,44 @@ ) from olmo_core.exceptions import OLMoConfigurationError -Mmaps = List[Union[Path, PathLike[Any], str]] +from ..utils import mk_mmaps -def _make_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> Mmaps: - mmaps: Mmaps = [] - for i in range(num_files): - filepath = f"{tmp_path}/{prefix}_{i}.npy" - data = np.random.randint(0, 2**32, size=size, dtype=np.uint32) - mm = np.memmap( - filepath, mode="w+", dtype=NumpyDatasetDType.uint32.as_np_dtype(), shape=(size,) - ) - mm[:] = data - mm.flush() - mmaps.append(Path(filepath)) +# def mk_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> Mmaps: +# mmaps: Mmaps = [] +# for i in range(num_files): +# filepath = f"{tmp_path}/{prefix}_{i}.npy" +# data = np.random.randint(0, 2**32, size=size, dtype=np.uint32) +# mm = np.memmap( +# filepath, mode="w+", dtype=NumpyDatasetDType.uint32.as_np_dtype(), shape=(size,) +# ) +# mm[:] = data +# mm.flush() +# mmaps.append(Path(filepath)) - return mmaps +# return mmaps def test_source_mixture_config(tmp_path: Path, capsys): source_paths = { - "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), - "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), - "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ SourceMixtureConfig( source_name="1", target_ratio=0.33, - paths=source_paths["1"], + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.33, paths=[i[0] for i in source_paths["2"]] ), - SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=source_paths["2"]), SourceMixtureConfig( source_name="3", target_ratio=0.34, - paths=source_paths["3"], + paths=[i[0] for i in source_paths["3"]], ), ] @@ -123,22 +125,24 @@ def test_dataset_mixture_config_validation(): def test_dataset_mixture_build(tmp_path: Path): source_paths = { - "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), - "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), - "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ SourceMixtureConfig( source_name="1", target_ratio=0.33, - paths=source_paths["1"], + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.33, paths=[i[0] for i in source_paths["2"]] ), - SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=source_paths["2"]), SourceMixtureConfig( source_name="3", target_ratio=0.34, - paths=source_paths["3"], + paths=[i[0] for i in source_paths["3"]], ), ] @@ -157,21 +161,23 @@ def test_dataset_mixture_build(tmp_path: Path): def test_dataset_mixture_build_insufficient_source_data(tmp_path: Path): source_paths = { - "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), - "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), - "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ SourceMixtureConfig( source_name="1", target_ratio=0.5, - paths=source_paths["1"], + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.25, paths=[i[0] for i in source_paths["2"]] ), - SourceMixtureConfig(source_name="2", target_ratio=0.25, paths=source_paths["2"]), SourceMixtureConfig( source_name="3", target_ratio=0.25, - paths=source_paths["3"], + paths=[i[0] for i in source_paths["3"]], ), ] @@ -196,9 +202,9 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): Source 1 has a target ratio of 90% and a max repetition ratio of 4.0, so it should be possible to meet the target of 3600 tokens with 1 file of 1000 tokens repeated 4 times. """ source_paths = { - "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), - "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), - "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ @@ -206,13 +212,15 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): source_name="1", target_ratio=0.5, max_repetition_ratio=3.0, # Allow 3x repetition of source1 so that we can meet the target of 2.5M - paths=source_paths["1"], + paths=[i[0] for i in source_paths["1"]], + ), + SourceMixtureConfig( + source_name="2", target_ratio=0.25, paths=[i[0] for i in source_paths["2"]] ), - SourceMixtureConfig(source_name="2", target_ratio=0.25, paths=source_paths["2"]), SourceMixtureConfig( source_name="3", target_ratio=0.25, - paths=source_paths["3"], + paths=[i[0] for i in source_paths["3"]], ), ] @@ -238,26 +246,26 @@ def test_dataset_mixture_build_with_repetition(tmp_path: Path): def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): source_paths = { - "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), - "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), - "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=1_000_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ SourceMixtureConfig( source_name="1", target_ratio=0.25, - paths=source_paths["1"], + paths=[i[0] for i in source_paths["1"]], max_source_fraction=0.10, # Allow only 10% of source1 to be used (population is 1M tokens) ), SourceMixtureConfig( source_name="2", target_ratio=0.25, - paths=source_paths["2"], + paths=[i[0] for i in source_paths["2"]], ), SourceMixtureConfig( source_name="3", target_ratio=0.5, - paths=source_paths["3"], + paths=[i[0] for i in source_paths["3"]], ), ] @@ -280,9 +288,9 @@ def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): # TODO: Handle duplicate paths in source mixture def test_dataset_mixture_build_duplicate_paths(tmp_path: Path): sources = { - "1": _make_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=500_000), - "2": _make_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), - "3": _make_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), + "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=500_000), + "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), + "3": mk_mmaps(tmp_path=tmp_path, prefix="source3", num_files=2, size=1_000_000), } source_configs = [ @@ -290,13 +298,13 @@ def test_dataset_mixture_build_duplicate_paths(tmp_path: Path): source_name="1", target_ratio=0.33, # 990k tokens max_repetition_ratio=2.0, - paths=[sources["1"][0], sources["1"][0]], # Duplicate the 1 path for source 1 + paths=[sources["1"][0][0], sources["1"][0][0]], # Duplicate the 1 path for source 1 ), - SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=sources["2"]), + SourceMixtureConfig(source_name="2", target_ratio=0.33, paths=[i[0] for i in sources["2"]]), SourceMixtureConfig( source_name="3", target_ratio=0.34, - paths=sources["3"], + paths=[i[0] for i in sources["3"]], ), ] @@ -309,10 +317,11 @@ def test_dataset_mixture_build_duplicate_paths(tmp_path: Path): sequence_length=1024, ) + expected = [sources["1"][0][0]] + [item[0] for item in list(chain(*sources.values()))] mixture = config.build() index = mixture.to_index() paths = mixture.to_paths() - assert paths == [sources["1"][0], sources["1"][0]] + sources["2"] + sources["3"] + assert paths == expected assert len(index) == 6 assert isinstance(mixture, SourceMixtureDataset) assert len(mixture.sources) == 3 diff --git a/src/test/utils.py b/src/test/utils.py index 4b04415e..d09e6e57 100644 --- a/src/test/utils.py +++ b/src/test/utils.py @@ -1,3 +1,8 @@ +from os import PathLike +from pathlib import Path +from typing import Any, List, Tuple, Type, Union + +import numpy as np import pytest import torch @@ -89,3 +94,32 @@ def get_default_device(): return torch.device("cuda") else: return torch.device("cpu") + + +Mmaps = List[Tuple[Union[Path, PathLike[Any], str], Any]] + + +def mk_mmaps( + tmp_path: Path, + prefix: str, + num_files: int, + size: int, + dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint32, + eos: int = 0, + seq_length: int = 4, + seed: int = 42, +) -> Mmaps: + mmaps: Mmaps = [] + for i in range(num_files): + filepath = f"{tmp_path}/{prefix}_{i}.npy" + np.random.seed(seed) + data = np.random.randint(0, np.iinfo(dtype).max, size=size, dtype=dtype) + data = np.append( + np.insert(data, np.arange(seq_length + 1, len(data), seq_length), eos), eos + ) + mm = np.memmap(filepath, mode="w+", dtype=dtype, shape=(len(data),)) + mm[:] = data + mm.flush() + mmaps.append((Path(filepath), data)) + + return mmaps From 8401580a4a9f0300bed2d10e96ffc57bf13f3d7f Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 14:03:54 -0700 Subject: [PATCH 37/57] Try diff gpus --- src/examples/train_with_mixture.py | 2 +- src/examples/train_with_mixture_launch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index 8ea7c685..bb251af7 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -74,7 +74,7 @@ def build_config(run_name: str) -> ExperimentConfig: vocab_size=tokenizer_config.padded_vocab_size(), compile=True, dp_config=TransformerDataParallelConfig( - name=DataParallelType.fsdp, + name=DataParallelType.ddp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32, ), diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py index 705d21bd..dd6899a9 100644 --- a/src/examples/train_with_mixture_launch.py +++ b/src/examples/train_with_mixture_launch.py @@ -22,7 +22,7 @@ def build_config(run_name: str) -> BeakerLaunchConfig: task_name="train", workspace="ai2/OLMo-core", description="Testing OLMo-core launch utilities", - clusters=["ai2/allennlp-elanding-a100-40g"], + clusters=["ai2/allennlp-cirrascale"], env_secrets=[ BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"), BeakerEnvSecret(name="WANDB_API_KEY", secret=f"{beaker_user}_WANDB_API_KEY"), From 0d77422fcdc2e8ec6596f8bef621982a7a09edfa Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 14:09:59 -0700 Subject: [PATCH 38/57] keep fsdp --- src/examples/train_with_mixture.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index bb251af7..dbfdf157 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -74,7 +74,7 @@ def build_config(run_name: str) -> ExperimentConfig: vocab_size=tokenizer_config.padded_vocab_size(), compile=True, dp_config=TransformerDataParallelConfig( - name=DataParallelType.ddp, + name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32, ), @@ -105,7 +105,7 @@ def build_config(run_name: str) -> ExperimentConfig: SourceMixtureConfig( paths=[f"s3://{path}" for path in baseline], source_name="baseline", - max_repetition_ratio=1.0, # 1.0 is a no-op but added here to illustrate the option + max_repetition_ratio=1.0, # 1.0 is default but here to illustrate options target_ratio=0.8, ), SourceMixtureConfig( @@ -114,6 +114,7 @@ def build_config(run_name: str) -> ExperimentConfig: target_ratio=0.2, ), ], + processes=10, dtype=NumpyDatasetDType.uint32, seed=42, ) From d22ed1036845574a8add41a98742b040ca049dd5 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 14:12:10 -0700 Subject: [PATCH 39/57] checks --- src/examples/train_with_mixture.py | 7 +------ src/test/data/data_loader_test.py | 1 - src/test/data/fixtures.py | 2 +- src/test/data/source_mixture_test.py | 18 ------------------ 4 files changed, 2 insertions(+), 26 deletions(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index dbfdf157..d3d7e68b 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -27,13 +27,8 @@ ) from olmo_core.data.types import NumpyDatasetDType from olmo_core.distributed.parallel import DataParallelType -from olmo_core.nn.transformer import ( - TransformerConfig, - TransformerDataParallelConfig, - TransformerDataParallelWrappingStrategy, -) from olmo_core.distributed.utils import init_hybrid_shard_mesh -from olmo_core.nn.transformer import TransformerConfig +from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride from olmo_core.train import ( Duration, diff --git a/src/test/data/data_loader_test.py b/src/test/data/data_loader_test.py index 13629154..8edaddda 100644 --- a/src/test/data/data_loader_test.py +++ b/src/test/data/data_loader_test.py @@ -202,7 +202,6 @@ def test_fsl_data_loader_with_mixture( num_threads: int, batch_size: int, # in tokens ): - dataset = get_fsl_mixture(tmp_path, sequence_length=sequence_length, num_tokens=num_tokens) assert batch_size % sequence_length == 0 assert batch_size % world_size == 0 diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py index 6a48bc87..b7590c3b 100644 --- a/src/test/data/fixtures.py +++ b/src/test/data/fixtures.py @@ -8,11 +8,11 @@ NumpyFSLDatasetMixtureConfig, TokenizerConfig, ) -from olmo_core.data.types import NumpyDatasetDType from olmo_core.data.source_mixture import ( SourceMixtureConfig, SourceMixtureDatasetConfig, ) +from olmo_core.data.types import NumpyDatasetDType from ..utils import mk_mmaps diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py index b640be4d..743fd182 100644 --- a/src/test/data/source_mixture_test.py +++ b/src/test/data/source_mixture_test.py @@ -1,9 +1,6 @@ from itertools import chain -from os import PathLike from pathlib import Path -from typing import Any, List, Union -import numpy as np import pytest from olmo_core.data import NumpyDatasetDType @@ -17,21 +14,6 @@ from ..utils import mk_mmaps -# def mk_mmaps(tmp_path: Path, prefix: str, num_files: int, size: int) -> Mmaps: -# mmaps: Mmaps = [] -# for i in range(num_files): -# filepath = f"{tmp_path}/{prefix}_{i}.npy" -# data = np.random.randint(0, 2**32, size=size, dtype=np.uint32) -# mm = np.memmap( -# filepath, mode="w+", dtype=NumpyDatasetDType.uint32.as_np_dtype(), shape=(size,) -# ) -# mm[:] = data -# mm.flush() -# mmaps.append(Path(filepath)) - -# return mmaps - - def test_source_mixture_config(tmp_path: Path, capsys): source_paths = { "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), From 68a4d28003f5d2c2c5a3699bfaa199bed34f8595 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 14:13:46 -0700 Subject: [PATCH 40/57] Less tokens --- src/examples/train_with_mixture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index d3d7e68b..d4f6cf07 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -94,7 +94,7 @@ def build_config(run_name: str) -> ExperimentConfig: sequence_length = 1024 source_config = SourceMixtureDatasetConfig( - max_tokens=int(10e8), # 100M tokens + max_tokens=int(10e7), # 100M tokens sequence_length=sequence_length, source_configs=[ SourceMixtureConfig( From c35514c99f9ab275699e1c167448850d55598a3a Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Fri, 25 Oct 2024 15:48:08 -0700 Subject: [PATCH 41/57] Exclude ai2/allennlp-elanding-a100-40g temp --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cba7adb2..1708052e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -181,7 +181,7 @@ jobs: constraints: cluster: - ai2/allennlp-cirrascale - - ai2/allennlp-elanding-a100-40g + # - ai2/allennlp-elanding-a100-40g - ai2/pluto-cirrascale - ai2/jupiter-cirrascale-2 envVars: From c453e65fde22c86051fc716189b9b6f0ade5c1fa Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:15:34 -0700 Subject: [PATCH 42/57] Feedback --- pyproject.toml | 2 +- src/examples/train_with_mixture.py | 5 +- src/olmo_core/data/__init__.py | 2 - src/olmo_core/data/numpy_dataset.py | 497 ++++++++++++--------------- src/olmo_core/data/source_mixture.py | 94 +++-- src/olmo_core/data/types.py | 2 +- src/test/data/data_loader_test.py | 57 --- src/test/data/fixtures.py | 11 +- src/test/data/numpy_dataset_test.py | 9 +- src/test/data/source_mixture_test.py | 10 +- 10 files changed, 284 insertions(+), 405 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2b99313..8f24e277 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ dependencies = [ "omegaconf", "safetensors", "importlib_resources", - "s3fs", # REMOVE THIS IN FAVOR OF SOMETHING CONSISTENT ELSEWHERE "tabulate", "tqdm", ] @@ -53,6 +52,7 @@ dev = [ "sphinx-copybutton==0.5.2", "sphinx-autobuild==2021.3.14", "sphinx-autodoc-typehints==1.23.3", + "s3fs", ] beaker = [ "beaker-py", diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py index d4f6cf07..4451e5c2 100644 --- a/src/examples/train_with_mixture.py +++ b/src/examples/train_with_mixture.py @@ -18,7 +18,6 @@ NumpyDataLoaderConfig, NumpyDatasetConfig, NumpyDatasetType, - NumpyFSLDatasetMixtureConfig, TokenizerConfig, ) from olmo_core.data.source_mixture import ( @@ -55,7 +54,7 @@ class ExperimentConfig(Config): model: TransformerConfig optim: AdamWConfig - dataset: NumpyFSLDatasetMixtureConfig + dataset: NumpyDatasetConfig data_loader: NumpyDataLoaderConfig trainer: TrainerConfig init_seed: int = 12536 @@ -114,7 +113,7 @@ def build_config(run_name: str) -> ExperimentConfig: seed=42, ) - dataset_config = NumpyFSLDatasetMixtureConfig( + dataset_config = NumpyDatasetConfig( source_mixture_config=source_config, sequence_length=sequence_length, max_target_sequence_length=8192, diff --git a/src/olmo_core/data/__init__.py b/src/olmo_core/data/__init__.py index 07b4bfa5..b710100e 100644 --- a/src/olmo_core/data/__init__.py +++ b/src/olmo_core/data/__init__.py @@ -25,7 +25,6 @@ NumpyDatasetBase, NumpyDatasetConfig, NumpyFSLDataset, - NumpyFSLDatasetMixtureConfig, NumpyPaddedFSLDataset, NumpyVSLDataset, VSLCurriculum, @@ -42,7 +41,6 @@ __all__ = [ "NumpyDatasetBase", "NumpyFSLDataset", - "NumpyFSLDatasetMixtureConfig", "NumpyPaddedFSLDataset", "NumpyVSLDataset", "VSLCurriculum", diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 4b0d923d..f2764818 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -30,7 +30,7 @@ from torch.utils.data import Dataset from olmo_core.data.source_mixture import SourceMixtureDatasetConfig -from olmo_core.data.types import NumpyDatasetDType, NumpyDatasetType, SupportedDType +from olmo_core.data.types import NumpyDatasetDType, NumpyDatasetType, NumpyUIntTypes from olmo_core.exceptions import OLMoConfigurationError, OLMoEnvironmentError from ..aliases import PathOrStr @@ -55,7 +55,6 @@ __all__ = [ "NumpyDatasetBase", "NumpyFSLDataset", - "NumpyFSLDatasetMixtureConfig", "NumpyPaddedFSLDataset", "VSLCurriculum", "VSLNaturalCurriculum", @@ -100,7 +99,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: SupportedDType = np.uint16, + dtype: NumpyUIntTypes = np.uint16, ): if not paths: raise OLMoConfigurationError("At least one path is required") @@ -154,7 +153,7 @@ def vocab_size(self) -> int: @property def dtype( self, - ) -> SupportedDType: + ) -> NumpyUIntTypes: """ The numpy datatype of the arrays. """ @@ -209,6 +208,13 @@ def work_dir_set(self) -> bool: """ return self._work_dir_set + @property + def num_tokens(self) -> int: + """ + Get the total number of tokens in the dataset. + """ + raise NotImplementedError + def _get_file_size(self, path: PathOrStr): path_idx = self.paths.index(path) return self.file_sizes[path_idx] @@ -270,16 +276,9 @@ def map( raise ValueError(method) with executor_class(max_workers=max_workers) as executor: - path_to_future = {} - for idx, path in enumerate(paths): - if path not in path_to_future: - path_to_future[path] = executor.submit(func, (path, idx)) - - results = [] - for path in paths: - results.append(path_to_future[path].result()) + futures = [executor.submit(func, (path, idx)) for idx, path in enumerate(paths)] - return results + return [future.result() for future in futures] def prepare(self): """ @@ -348,7 +347,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: SupportedDType = np.uint16, + dtype: NumpyUIntTypes = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, generate_doc_lengths: bool = False, @@ -485,7 +484,7 @@ def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: ) def _get_file_size_and_length( - self, item: Tuple[PathOrStr, int], dtype: Optional[SupportedDType] = None + self, item: Tuple[PathOrStr, int], dtype: Optional[NumpyUIntTypes] = None ) -> Tuple[int, int]: path, _ = item dtype = dtype or self.dtype @@ -507,6 +506,156 @@ def _get_file_size_and_length( raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") +class NumpyFSLDatasetMixture(NumpyFSLDataset): + """ + A version of :class:`NumpyFSLDataset` built from a mixture of sources and their expected token ratios relative to each other. A path_offset_index is used to determine the number of instances to retain from a path when constructing the local indices. + """ + + def __init__( + self, + *paths: PathOrStr, + path_offset_index: Dict[Tuple[str, int], int], + sequence_length: int, + pad_token_id: int, + eos_token_id: int, + vocab_size: int, + dtype: NumpyUIntTypes = np.uint16, + metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, + include_instance_metadata: Optional[bool] = None, + generate_doc_lengths: bool = False, + max_target_sequence_length: Optional[int] = None, + bust_index_cache: bool = False, + ): + if max_target_sequence_length is not None and ( + max_target_sequence_length < sequence_length + or max_target_sequence_length % sequence_length != 0 + ): + raise OLMoConfigurationError( + "'max_target_sequence_length' should be a multiple of 'sequence_length'" + ) + + if include_instance_metadata is None and metadata: + include_instance_metadata = True + + if isinstance(metadata, list): + if len(metadata) != len(paths): + raise OLMoConfigurationError( + "'metadata' should have the same length as the number of file paths" + ) + else: + metadata = [metadata or {}] * len(paths) + + super().__init__( + *paths, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + vocab_size=vocab_size, + dtype=dtype, + sequence_length=sequence_length, + metadata=metadata, + include_instance_metadata=include_instance_metadata, + generate_doc_lengths=generate_doc_lengths, + max_target_sequence_length=max_target_sequence_length, + ) + self._metadata = tuple(metadata) + self._include_instance_metadata = include_instance_metadata + self._num_instances: Optional[int] = None + self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None + self._lengths_dtype: Optional[NumpyUIntTypes] = None + self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None + self._path_offset_index = path_offset_index + self._bust_index_cache = bust_index_cache + + def prepare(self): + if self.fs_local_rank == 0: + log.info("Gathering indices...") + self._write_document_indices() + barrier() + len(self) + + def _get_indices_path(self, path: PathOrStr) -> Path: + sha256_hash = hashlib.sha256() + sha256_hash.update(str(path).encode()) + sha256_hash.update(str(self._get_file_size(path)).encode()) + path_hash = sha256_hash.hexdigest() + return ( + self.work_dir + / "dataset-common" + / f"mixture-instance-indices-{self.sequence_length}-{path_hash}.npy" + ) + + def _write_document_indices(self): + paths_needed: List[Tuple[PathOrStr, int]] = [] + for idx, path in enumerate(self.paths): + indices_path = self._get_indices_path(path) + if indices_path.is_file() and not self._bust_index_cache: + log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'") + elif path not in paths_needed: + paths_needed.append((path, idx)) + + if paths_needed: + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] + for path, idx in paths_needed: + indices_path = self._get_indices_path(path) + log.info(f"Gathering instance indices for '{path}'...") + # NOTE: We limit the number of instances by total target token count // sequence length + max_instances = ( + self._path_offset_index[(str(path), idx)] // self.sequence_length + ) + future = executor.submit( + run_worker_func, + segment_documents_into_instances, + path, + indices_path, + max_sequence_length=self.sequence_length, + eos_token_id=self.eos_token_id, + dtype=self.dtype, + indices_dtype=self.dtype, + max_instances=max_instances, + ) + futures.append(future) + + concurrent.futures.wait(futures, return_when="ALL_COMPLETED") + + # Log results. + for path, future in zip([item[0] for item in paths_needed], futures): + _, total_instances = future.result() + log.info( + f"Created {total_instances:,d} instances of sequence length up to " + f"{self.sequence_length} from '{path}'" + ) + + def _get_file_size_and_length( + self, item: Tuple[PathOrStr, int], dtype: Optional[NumpyUIntTypes] = None + ) -> Tuple[int, int]: + dtype = dtype or self.dtype + item_size = dtype(0).itemsize + file_size = self._get_size_from_offset_index(item) + if ( + self.max_target_sequence_length is None + or self.max_target_sequence_length == self.sequence_length + ): + return file_size, file_size // (item_size * self.sequence_length) + elif self.max_target_sequence_length > self.sequence_length: + num_max_seq_len_instances = file_size // (item_size * self.max_target_sequence_length) + return ( + file_size, + num_max_seq_len_instances + * (self.max_target_sequence_length // self.sequence_length), + ) + else: + raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") + + def _get_size_from_offset_index(self, path_index: Tuple[PathOrStr, int]) -> int: + try: + path, idx = path_index + # Get size in bytes from tokens in the supplied index * itemsize + return self._path_offset_index[(str(path), idx)] * self.dtype(0).itemsize + except KeyError: + raise OLMoEnvironmentError(f"Item not found in path index @ {path_index}") + + class NumpyPaddedFSLDataset(NumpyFSLDataset): """ A version of :class:`NumpyFSLDataset` that creates a single instance from each document. @@ -520,7 +669,7 @@ def __init__( pad_token_id: int, eos_token_id: int, vocab_size: int, - dtype: SupportedDType = np.uint16, + dtype: NumpyUIntTypes = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, ): @@ -555,7 +704,7 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]: @property def indices_dtype( self, - ) -> SupportedDType: + ) -> NumpyUIntTypes: return np.uint32 def prepare(self): @@ -950,7 +1099,7 @@ def __init__( max_sequence_length: int, min_sequence_length: int = 256, curriculum: Optional[VSLCurriculum] = None, - dtype: SupportedDType = np.uint16, + dtype: NumpyUIntTypes = np.uint16, metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, include_instance_metadata: Optional[bool] = None, ): @@ -990,7 +1139,7 @@ def __init__( self._curriculum = curriculum or VSLNaturalCurriculum() self._num_instances: Optional[int] = None self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None - self._lengths_dtype: Optional[SupportedDType] = None + self._lengths_dtype: Optional[NumpyUIntTypes] = None self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None @property @@ -1229,13 +1378,13 @@ def instances_per_bucket(self) -> Tuple[Tuple[int, int], ...]: @property def indices_dtype( self, - ) -> SupportedDType: + ) -> NumpyUIntTypes: return np.uint32 @property def lengths_dtype( self, - ) -> SupportedDType: + ) -> NumpyUIntTypes: if self._lengths_dtype is None: for dtype in ( np.uint8, @@ -1334,6 +1483,14 @@ class NumpyDatasetConfig(Config): """ The type of dataset. """ + bust_index_cache: bool = False + """ + Whether or not to bust the index cache. + """ + source_mixture_config: Optional[SourceMixtureDatasetConfig] = None + """ + The source mixture dataset config. + """ sequence_length: Optional[int] = None """ The sequence length for a :class:`NumpyFSLDataset`. @@ -1407,6 +1564,10 @@ def validate(self): self.sequence_length = None self.max_target_sequence_length = None + if self.source_mixture_config and self.mix: + # NOTE(tylerm): This could be revisited as I think they could play nicely together. + raise OLMoConfigurationError("Only one of 'source_mixture_config' or 'mix' can be set") + @property def effective_sequence_length(self) -> int: if self.sequence_length is not None: @@ -1450,7 +1611,7 @@ def from_data_mix( def get_dtype( self, - ) -> SupportedDType: + ) -> NumpyUIntTypes: if self.dtype is not None: return NumpyDatasetDType(self.dtype).as_np_dtype() @@ -1471,8 +1632,10 @@ def build(self) -> NumpyDatasetBase: """ Construct the corresponding :class:`NumpyDatasetBase`. """ - if (self.paths is None) == (self.mix is None): - raise OLMoConfigurationError("Exactly one of 'paths' or 'mix' is required") + if (self.paths is None) == (self.mix is None) == (self.source_mixture_config is None): + raise OLMoConfigurationError( + "Exactly one of 'paths' or 'mix' or 'source_mixture' is required" + ) paths: List[str] = [] metadata = self.metadata @@ -1489,6 +1652,8 @@ def build(self) -> NumpyDatasetBase: paths.extend(matches) elif self.paths: paths = self.paths + elif self.source_mixture_config: + log.info("Building dataset from source mixture...") else: assert self.mix is not None if self.mix_base_dir is None: @@ -1525,18 +1690,35 @@ def build(self) -> NumpyDatasetBase: raise OLMoConfigurationError( "'vsl_curriculum' is only a valid field for VSL datasets" ) - dataset = NumpyFSLDataset( - *paths, - sequence_length=self.sequence_length, - max_target_sequence_length=self.max_target_sequence_length, - pad_token_id=self.tokenizer.pad_token_id, - eos_token_id=self.tokenizer.eos_token_id, - vocab_size=self.tokenizer.vocab_size, - dtype=self.get_dtype(), - metadata=metadata, - include_instance_metadata=self.include_instance_metadata, - generate_doc_lengths=self.generate_doc_lengths, - ) + if self.source_mixture_config: + mixture = self.source_mixture_config.build() + return NumpyFSLDatasetMixture( + *mixture.to_paths(), + sequence_length=self.sequence_length, + max_target_sequence_length=self.max_target_sequence_length, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + vocab_size=self.tokenizer.vocab_size, + dtype=self.get_dtype(), + metadata=self.metadata, + include_instance_metadata=self.include_instance_metadata, + generate_doc_lengths=self.generate_doc_lengths, + path_offset_index=mixture.to_index(), + bust_index_cache=self.bust_index_cache, + ) + else: + dataset = NumpyFSLDataset( + *paths, + sequence_length=self.sequence_length, + max_target_sequence_length=self.max_target_sequence_length, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + vocab_size=self.tokenizer.vocab_size, + dtype=self.get_dtype(), + metadata=metadata, + include_instance_metadata=self.include_instance_metadata, + generate_doc_lengths=self.generate_doc_lengths, + ) elif self.name == NumpyDatasetType.padded_fsl: if self.sequence_length is None: raise OLMoConfigurationError("'sequence_length' is required for padded FSL dataset") @@ -1608,246 +1790,3 @@ def build(self) -> NumpyDatasetBase: dataset.work_dir = Path(self.work_dir) return dataset - - -class NumpyFSLDatasetMixture(NumpyFSLDataset): - def __init__( - self, - *paths: PathOrStr, - path_offset_index: Dict[Tuple[str, int], int], - sequence_length: int, - pad_token_id: int, - eos_token_id: int, - vocab_size: int, - dtype: SupportedDType = np.uint16, - metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None, - include_instance_metadata: Optional[bool] = None, - generate_doc_lengths: bool = False, - max_target_sequence_length: Optional[int] = None, - bust_index_cache: bool = False, - ): - if max_target_sequence_length is not None and ( - max_target_sequence_length < sequence_length - or max_target_sequence_length % sequence_length != 0 - ): - raise OLMoConfigurationError( - "'max_target_sequence_length' should be a multiple of 'sequence_length'" - ) - - if include_instance_metadata is None and metadata: - include_instance_metadata = True - - if isinstance(metadata, list): - if len(metadata) != len(paths): - raise OLMoConfigurationError( - "'metadata' should have the same length as the number of file paths" - ) - else: - metadata = [metadata or {}] * len(paths) - - super().__init__( - *paths, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - vocab_size=vocab_size, - dtype=dtype, - sequence_length=sequence_length, - metadata=metadata, - include_instance_metadata=include_instance_metadata, - generate_doc_lengths=generate_doc_lengths, - max_target_sequence_length=max_target_sequence_length, - ) - self._metadata = tuple(metadata) - self._include_instance_metadata = include_instance_metadata - self._num_instances: Optional[int] = None - self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None - self._lengths_dtype: Optional[SupportedDType] = None - self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None - self._path_offset_index = path_offset_index - self._bust_index_cache = bust_index_cache - - def prepare(self): - if self.fs_local_rank == 0: - log.info("Gathering indices...") - self._write_document_indices() - barrier() - len(self) - - def _get_indices_path(self, path: PathOrStr) -> Path: - sha256_hash = hashlib.sha256() - sha256_hash.update(str(path).encode()) - sha256_hash.update(str(self._get_file_size(path)).encode()) - path_hash = sha256_hash.hexdigest() - return ( - self.work_dir - / "dataset-common" - / f"mixture-instance-indices-{self.sequence_length}-{path_hash}.npy" - ) - - def _write_document_indices(self): - paths_needed: List[Tuple[PathOrStr, int]] = [] - for idx, path in enumerate(self.paths): - indices_path = self._get_indices_path(path) - if indices_path.is_file() and not self._bust_index_cache: - log.info(f"Reusing document indices for '{path}' at:\n'{indices_path}'") - elif path not in paths_needed: - paths_needed.append((path, idx)) - - if paths_needed: - with concurrent.futures.ProcessPoolExecutor() as executor: - futures = [] - for path, idx in paths_needed: - indices_path = self._get_indices_path(path) - log.info(f"Gathering instance indices for '{path}'...") - # NOTE: We limit the number of instances by total target token count // sequence length - max_instances = ( - self._path_offset_index[(str(path), idx)] // self.sequence_length - ) - future = executor.submit( - run_worker_func, - segment_documents_into_instances, - path, - indices_path, - max_sequence_length=self.sequence_length, - eos_token_id=self.eos_token_id, - dtype=self.dtype, - indices_dtype=self.dtype, - max_instances=max_instances, - ) - futures.append(future) - - concurrent.futures.wait(futures, return_when="ALL_COMPLETED") - - # Log results. - for path, future in zip([item[0] for item in paths_needed], futures): - _, total_instances = future.result() - log.info( - f"Created {total_instances:,d} instances of sequence length up to " - f"{self.sequence_length} from '{path}'" - ) - - def _get_file_size_and_length( - self, item: Tuple[PathOrStr, int], dtype: Optional[SupportedDType] = None - ) -> Tuple[int, int]: - path, idx = item - dtype = dtype or self.dtype - item_size = dtype(0).itemsize - file_size = self._get_size_from_offset_index(item) - if ( - self.max_target_sequence_length is None - or self.max_target_sequence_length == self.sequence_length - ): - return file_size, file_size // (item_size * self.sequence_length) - elif self.max_target_sequence_length > self.sequence_length: - num_max_seq_len_instances = file_size // (item_size * self.max_target_sequence_length) - return ( - file_size, - num_max_seq_len_instances - * (self.max_target_sequence_length // self.sequence_length), - ) - else: - raise RuntimeError("invalid 'max_target_sequence_length' or 'sequence_length'") - - def _get_size_from_offset_index(self, path_index: Tuple[PathOrStr, int]) -> int: - try: - path, idx = path_index - # Get size in bytes from tokens in the supplied index * itemsize - return self._path_offset_index[(str(path), idx)] * self.dtype(0).itemsize - except KeyError: - raise OLMoEnvironmentError(f"Item not found in path index @ {path_index}") - - -@dataclass -class NumpyFSLDatasetMixtureConfig(Config): - """ - A config class for easily building :class:`NumpyFSLDatasetMixture` class. - This is a special case of :class:`NumpyFSLDataset` that is built from a mixture of source - datasets based on a source mixture configuration. - """ - - source_mixture_config: SourceMixtureDatasetConfig - """ - The source mixture dataset config. - """ - tokenizer: TokenizerConfig - """ - The tokenizer config. - """ - sequence_length: Optional[int] = None - """ - The sequence length for a :class:`NumpyFSLDataset`. - """ - max_target_sequence_length: Optional[int] = None - """ - The max target sequene length for a :class:`NumpyFSLDataset`. - """ - dtype: Optional[NumpyDatasetDType] = None - """ - The numpy datatype of the token ID arrays. - """ - metadata: Optional[List[Dict[str, Any]]] = None - """ - Metadata for the numpy arrays. - """ - include_instance_metadata: bool = False - """ - Whether or not to include the :data:`metadata` in the instances returned from - :meth:`NumpyDatasetBase.__getitem__()`. - """ - generate_doc_lengths: bool = False - """ - Include individual document lengths in the instances returned from - :meth:`NumpyDatasetBase.__getitem__()`. - """ - work_dir: Optional[str] = None - """ - The dataset working directory. This is used to cache working files like shuffled indices, - instance buckets, etc. - - .. tip:: - You can save a lot of time and disk space by setting this to a common directory across - all of you runs. - """ - bust_index_cache: bool = False - """ - Whether or not to bust the index cache. - """ - - def get_dtype( - self, - ) -> SupportedDType: - if self.dtype is not None: - return NumpyDatasetDType(self.dtype).as_np_dtype() - - # Guess based on vocab size. - for dtype in ( - NumpyDatasetDType.uint8, - NumpyDatasetDType.uint16, - NumpyDatasetDType.uint32, - NumpyDatasetDType.uint64, - ): - if (self.tokenizer.vocab_size - 1) <= np.iinfo(dtype.as_np_dtype()).max: - log.info(f"Assuming dtype '{dtype}' based on vocab size") - return dtype.as_np_dtype() - - raise ValueError("vocab size too big!") - - def build(self) -> NumpyFSLDataset: - """ - Construct the corresponding :class:`NumpyFSLDatasetMixture`. - """ - mixture = self.source_mixture_config.build() - return NumpyFSLDatasetMixture( - *mixture.to_paths(), - sequence_length=self.sequence_length or 1024, - max_target_sequence_length=self.max_target_sequence_length, - pad_token_id=self.tokenizer.pad_token_id, - eos_token_id=self.tokenizer.eos_token_id, - vocab_size=self.tokenizer.vocab_size, - dtype=self.get_dtype(), - metadata=self.metadata, - include_instance_metadata=self.include_instance_metadata, - generate_doc_lengths=self.generate_doc_lengths, - path_offset_index=mixture.to_index(), - bust_index_cache=self.bust_index_cache, - ) diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index 2d3d07e0..426158b5 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -4,9 +4,10 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from itertools import chain -from pprint import pprint from typing import Dict, List, Optional, Tuple +from rich.console import Console +from rich.table import Table import tabulate from tqdm import tqdm @@ -22,22 +23,8 @@ "SourceMixtureDatasetConfig", ] -logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) -# Disable some noisy loggers -for name in logging.Logger.manager.loggerDict.keys(): - if name in ( - "boto", - "urllib3", - "s3transfer", - "boto3", - "botocore", - "aiobotocore", - "nose", - ): - logging.getLogger(name).setLevel(logging.CRITICAL) - @dataclass class SourceMixtureConfig(Config): @@ -79,9 +66,9 @@ def for_table(self, max_tokens: int) -> Dict: "source_name": self.config.source_name, "source_population": f"{self.population:.2e}", "num_sampled": f"{self.num_selected:.2e}", - "target_ratio": self.config.target_ratio, - "max_repetion_ratio": self.config.max_repetition_ratio, - "max_source_fraction": self.config.max_source_fraction, + "target_ratio": str(self.config.target_ratio), + "max_repetion_ratio": str(self.config.max_repetition_ratio), + "max_source_fraction": str(self.config.max_source_fraction), "observed_source_ratio": f"{(self.num_selected / self.population):.4}", "observed_global_ratio": f"{(self.num_selected / max_tokens):.4}", } @@ -156,10 +143,9 @@ def build(self) -> SourceMixtureDataset: random.seed(self.seed) available_tokens_by_source: Dict[str, int] = {} - print("---------------------------------------------------------") - print("Generating a source mixture from configurations:") - for source_config in self.source_configs: - pprint(source_config) + log.info("---------------------------------------------------------") + log.info("Generating a source mixture from configurations:") + log.info(self.source_configs) # Count the number of tokens available for each source for source_config in self.source_configs: @@ -205,33 +191,7 @@ def build(self) -> SourceMixtureDataset: ) ) - log.info("Outcome by source => ") - print( - tabulate.tabulate( - [item.for_table(self.max_tokens) for item in tokens_details_by_source], - headers="keys", - tablefmt="pretty", - ), - ) - - total_tokens = sum([item.population for item in tokens_details_by_source]) - selected_tokens = sum([item.num_selected for item in tokens_details_by_source]) - observed_global_ratio = selected_tokens / total_tokens - - log.info("Global outcome => ") - print( - tabulate.tabulate( - [ - { - "total_tokens": f"{total_tokens:.2e}", - "selected_tokens": f"{selected_tokens:.2e}", - "observed_global_ratio": f"{observed_global_ratio:.4}", - } - ], - tablefmt="pretty", - headers="keys", - ), - ) + self.render_mixture_outcome_tables(tokens_details_by_source) for outcome in completed: for item in outcome.path_tokens: @@ -305,3 +265,39 @@ def _bytes_to_tokens(self, num_bytes: int, dtype: NumpyDatasetDType) -> int: """ npdtype = dtype.as_np_dtype() return num_bytes // npdtype(int(0)).itemsize + + def render_mixture_outcome_tables(self, results: List[SourceTokenDetails]) -> None: + """ + Render tables enumerating the global and per-source mixture outcomes. + """ + + console = Console() + + source_rows = [item.for_table(self.max_tokens) for item in results] + source_headers = source_rows[0].keys() + + source_table = Table(title="Outcome by source") + for header in source_headers: + source_table.add_column(header) + + for row in source_rows: + source_table.add_row(*[row[header] for header in source_headers]) + + console.print(source_table) + + total_tokens = sum([item.population for item in results]) + selected_tokens = sum([item.num_selected for item in results]) + observed_global_ratio = f"{(selected_tokens / total_tokens):.4}" + + global_table = Table(title="Global outcome") + global_headers = [ + "total_tokens", + "selected_tokens", + "observed_global_ratio", + ] + + for header in global_headers: + global_table.add_column(header) + + global_table.add_row(f"{total_tokens:.2e}", f"{selected_tokens:.2e}", observed_global_ratio) + console.print(global_table) diff --git a/src/olmo_core/data/types.py b/src/olmo_core/data/types.py index 2b814b28..d08571f2 100644 --- a/src/olmo_core/data/types.py +++ b/src/olmo_core/data/types.py @@ -4,7 +4,7 @@ from olmo_core.config import StrEnum -SupportedDType = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] +NumpyUIntTypes = Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] class NumpyDatasetType(StrEnum): diff --git a/src/test/data/data_loader_test.py b/src/test/data/data_loader_test.py index 8edaddda..f552f86e 100644 --- a/src/test/data/data_loader_test.py +++ b/src/test/data/data_loader_test.py @@ -187,63 +187,6 @@ def test_fsl_data_loader_multiple_epochs( assert data_loader.tokens_processed == 0 -@pytest.mark.parametrize( - "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size", - [ - (100, 4, 2, 2, 2, 8), # 2 instances per batch, 12 instances total - ], -) -def test_fsl_data_loader_with_mixture( - tmp_path: Path, - num_tokens: int, - sequence_length: int, - world_size: int, - num_workers: int, - num_threads: int, - batch_size: int, # in tokens -): - dataset = get_fsl_mixture(tmp_path, sequence_length=sequence_length, num_tokens=num_tokens) - assert batch_size % sequence_length == 0 - assert batch_size % world_size == 0 - rank_batch_size = batch_size // world_size - assert rank_batch_size > 0 - num_batches = num_tokens // batch_size - - def get_all_batches() -> List[List[int]]: - all_batches: List[List[int]] = [[] for _ in range(num_batches)] - for rank in range(world_size): - data_loader = NumpyFSLDataLoader( - dataset, - global_batch_size=batch_size, - collator=DataCollator(pad_token_id=-1), - shuffle=False, - num_threads=num_threads, - work_dir=tmp_path, - dp_rank=rank, - dp_world_size=world_size, - num_workers=num_workers, - ) - data_loader.reshuffle(epoch=1) - batches = list(data_loader) - assert len(batches) == num_batches - for i, batch in enumerate(batches): - for instance in batch["input_ids"]: - all_batches[i].extend(instance.tolist()) - return all_batches - - all_batches = get_all_batches() - all_tokens = [] - assert len(all_batches) == num_batches - for batch in all_batches: - assert len(batch) == batch_size - all_tokens.extend(batch) - - ds_tokens = list(chain.from_iterable([i["input_ids"].tolist() for i in dataset])) - - assert len(all_tokens) == num_batches * batch_size - assert set(all_tokens) == set(ds_tokens) - - @pytest.mark.parametrize( "shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no-shuffle")] ) diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py index b7590c3b..941445c7 100644 --- a/src/test/data/fixtures.py +++ b/src/test/data/fixtures.py @@ -4,8 +4,9 @@ import numpy as np from olmo_core.data import ( + NumpyDatasetBase, NumpyFSLDataset, - NumpyFSLDatasetMixtureConfig, + NumpyDatasetConfig, TokenizerConfig, ) from olmo_core.data.source_mixture import ( @@ -24,13 +25,13 @@ def get_fsl_mixture( sequence_length: int = 4, num_tokens: int = 20 * 1000, eos: int = 0, -) -> NumpyFSLDataset: +) -> NumpyDatasetBase: seed = 42 mmap1 = mk_mmaps( - tmp_path, "mmap1", 1, num_tokens // 2, dtype, eos=eos, seed=seed, seq_length=sequence_length + tmp_path, "mmap1", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length ) mmap2 = mk_mmaps( - tmp_path, "mmap2", 1, num_tokens // 2, dtype, eos=eos, seed=seed, seq_length=sequence_length + tmp_path, "mmap2", 1, num_tokens * 2, dtype, eos=eos, seed=seed, seq_length=sequence_length ) tokenizer = TokenizerConfig( @@ -59,7 +60,7 @@ def get_fsl_mixture( seed=seed, ) - ds = NumpyFSLDatasetMixtureConfig( + ds = NumpyDatasetConfig( source_mixture_config=mixture_config, sequence_length=sequence_length, tokenizer=tokenizer, diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 11b62d4b..8c26f2fb 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -6,7 +6,6 @@ from olmo_core.data import ( NumpyDatasetConfig, NumpyFSLDataset, - NumpyFSLDatasetMixtureConfig, NumpyPaddedFSLDataset, NumpyVSLDataset, TokenizerConfig, @@ -107,7 +106,7 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): seed=seed, ) - ds = NumpyFSLDatasetMixtureConfig( + ds = NumpyDatasetConfig( source_mixture_config=mixture_config, sequence_length=sequence_length, tokenizer=tokenizer, @@ -126,7 +125,7 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): 15795, 52202, ] # stable because we pass a seed - assert ds.num_tokens == 10000 + # assert ds.num_tokens == 10000 assert len(ds) == 2500 @@ -166,7 +165,7 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): seed=seed, ) - ds = NumpyFSLDatasetMixtureConfig( + ds = NumpyDatasetConfig( source_mixture_config=mixture_config, sequence_length=sequence_length, tokenizer=tokenizer, @@ -185,7 +184,7 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): 15795, 52202, ] # stable because we pass a seed - assert ds.num_tokens == 10000 + # assert ds.num_tokens == 10000 assert len(ds) == 2500 diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py index 743fd182..87e2b4a0 100644 --- a/src/test/data/source_mixture_test.py +++ b/src/test/data/source_mixture_test.py @@ -1,3 +1,4 @@ +import logging from itertools import chain from pathlib import Path @@ -14,7 +15,7 @@ from ..utils import mk_mmaps -def test_source_mixture_config(tmp_path: Path, capsys): +def test_source_mixture_config(tmp_path: Path, caplog, capsys): source_paths = { "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=2, size=1_000_000), "2": mk_mmaps(tmp_path=tmp_path, prefix="source2", num_files=2, size=1_000_000), @@ -46,9 +47,12 @@ def test_source_mixture_config(tmp_path: Path, capsys): sequence_length=1024, ) - with capsys.disabled(): - print("\n") + # NOTE: We need to disable capsys so we can override log capture as + # we want to see the rendered tables in the case + with capsys.disabled(), caplog.at_level(logging.DEBUG): + config.validate() mixture = config.build() + print(caplog.text) assert isinstance(mixture, SourceMixtureDataset) From a288d9e50a058089a427282246539405bfbe30da Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:16:26 -0700 Subject: [PATCH 43/57] Drop examples --- src/examples/train_with_mixture.py | 241 ---------------------- src/examples/train_with_mixture_launch.py | 66 ------ 2 files changed, 307 deletions(-) delete mode 100644 src/examples/train_with_mixture.py delete mode 100644 src/examples/train_with_mixture_launch.py diff --git a/src/examples/train_with_mixture.py b/src/examples/train_with_mixture.py deleted file mode 100644 index 4451e5c2..00000000 --- a/src/examples/train_with_mixture.py +++ /dev/null @@ -1,241 +0,0 @@ -""" -Example of how to train a transformer language model with a source mixture config. - -Launch this with torchrun: - - torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name -""" - -import sys -from dataclasses import dataclass -from typing import cast - -import s3fs -from torch.distributed.elastic.multiprocessing.errors import record - -from olmo_core.config import Config, DType -from olmo_core.data import ( - NumpyDataLoaderConfig, - NumpyDatasetConfig, - NumpyDatasetType, - TokenizerConfig, -) -from olmo_core.data.source_mixture import ( - SourceMixtureConfig, - SourceMixtureDatasetConfig, -) -from olmo_core.data.types import NumpyDatasetDType -from olmo_core.distributed.parallel import DataParallelType -from olmo_core.distributed.utils import init_hybrid_shard_mesh -from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig -from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride -from olmo_core.train import ( - Duration, - TrainerConfig, - prepare_training_environment, - teardown_training_environment, -) -from olmo_core.train.callbacks import ( - CheckpointerCallback, - CometCallback, - ConfigSaverCallback, - GPUMemoryMonitorCallback, - GradClipperCallback, - LMEvaluatorCallbackConfig, - ProfilerCallback, - SchedulerCallback, - SequenceLengthSchedulerCallback, - WandBCallback, -) -from olmo_core.utils import get_default_device, seed_all - - -@dataclass -class ExperimentConfig(Config): - model: TransformerConfig - optim: AdamWConfig - dataset: NumpyDatasetConfig - data_loader: NumpyDataLoaderConfig - trainer: TrainerConfig - init_seed: int = 12536 - - -def build_config(run_name: str) -> ExperimentConfig: - tokenizer_config = TokenizerConfig.dolma2() - - model_config = TransformerConfig.llama2_271M( - # a little bigger than actual vocab size to make it a multiple of 128 - vocab_size=tokenizer_config.padded_vocab_size(), - compile=True, - dp_config=TransformerDataParallelConfig( - name=DataParallelType.fsdp, - param_dtype=DType.bfloat16, - reduce_dtype=DType.float32, - ), - ) - - optim_config = AdamWConfig( - lr=1e-3, - group_overrides=[ - OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) - ], - ) - - s3 = s3fs.S3FileSystem() - - # DCLM docs + rewrites - baseline = s3.glob( - "s3://ai2-llm/preprocessed/dclm/samples/src-100b/**/allenai/dolma2-tokenizer/*.npy" - ) - rewrites = s3.glob( - "s3://ai2-llm/preprocessed/dclm/samples/rewrite-100b/**/allenai/dolma2-tokenizer/*.npy" - ) - - sequence_length = 1024 - source_config = SourceMixtureDatasetConfig( - max_tokens=int(10e7), # 100M tokens - sequence_length=sequence_length, - source_configs=[ - SourceMixtureConfig( - paths=[f"s3://{path}" for path in baseline], - source_name="baseline", - max_repetition_ratio=1.0, # 1.0 is default but here to illustrate options - target_ratio=0.8, - ), - SourceMixtureConfig( - source_name="rewrites", - paths=[f"s3://{path}" for path in rewrites], - target_ratio=0.2, - ), - ], - processes=10, - dtype=NumpyDatasetDType.uint32, - seed=42, - ) - - dataset_config = NumpyDatasetConfig( - source_mixture_config=source_config, - sequence_length=sequence_length, - max_target_sequence_length=8192, - tokenizer=TokenizerConfig.dolma2(), - work_dir="/tmp/dataset-cache", - bust_index_cache=True, - ) - - data_loader_config = NumpyDataLoaderConfig( - global_batch_size=256 * sequence_length, - seed=0, - num_workers=4, - ) - - trainer_config = ( - TrainerConfig( - save_folder=f"/tmp/{run_name}", - rank_microbatch_size=16 * sequence_length, - save_overwrite=True, - metrics_collect_interval=5, - cancel_check_interval=5, - ) - .with_callback("lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=100))) - .with_callback( - "seq_len_scheduler", - SequenceLengthSchedulerCallback( - min_sequence_length=128, warmup_steps=100, enabled=False - ), - ) - .with_callback("gpu_monitor", GPUMemoryMonitorCallback()) - .with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0)) - .with_callback( - "checkpointer", - CheckpointerCallback( - save_interval=1000, - ephemeral_save_interval=100, - save_async=True, - ), - ) - .with_callback( - "comet", - CometCallback( - name=run_name, - cancel_check_interval=10, - enabled=False, # change to true to enable - ), - ) - .with_callback( - "wandb", - WandBCallback( - name=run_name, - cancel_check_interval=10, - enabled=False, # change to true to enable - ), - ) - .with_callback("config_saver", ConfigSaverCallback()) - .with_callback("profiler", ProfilerCallback(enabled=False)) - .with_callback( - "evaluator", - LMEvaluatorCallbackConfig( - eval_dataset=NumpyDatasetConfig( - paths=[ - "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy" - ], - metadata=[{"label": "c4-validation"}], - name=NumpyDatasetType.padded_fsl, - sequence_length=sequence_length, - tokenizer=tokenizer_config, - work_dir="/tmp/dataset-cache", - ), - eval_interval=250, - eval_duration=Duration.steps(10), - ), - ) - ) - - return ExperimentConfig( - model=model_config, - optim=optim_config, - dataset=dataset_config, - data_loader=data_loader_config, - trainer=trainer_config, - ) - - -@record -def main(run_name: str): - config = build_config(run_name) - - # Set RNG states on all devices. - seed_all(config.init_seed) - - # Build components. - model = config.model.build( - init_device="meta", - device=get_default_device(), - dp_mesh=init_hybrid_shard_mesh(num_replicas=2), - ) - optim = config.optim.build(model) - dataset = config.dataset.build() - data_loader = config.data_loader.build(dataset) - trainer = config.trainer.build(model, optim, data_loader) - - # Save config to W&B and each checkpoint dir. - config_dict = config.as_config_dict() - cast(CometCallback, trainer.callbacks["comet"]).config = config_dict - cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict - cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict - - # Train. - trainer.fit() - - -if __name__ == "__main__": - if len(sys.argv) < 2: - print(f"Usage: python {sys.argv[0]} run_name") - sys.exit(1) - - run_name = sys.argv[1] - - prepare_training_environment() - try: - main(run_name) - finally: - teardown_training_environment() diff --git a/src/examples/train_with_mixture_launch.py b/src/examples/train_with_mixture_launch.py deleted file mode 100644 index dd6899a9..00000000 --- a/src/examples/train_with_mixture_launch.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -An example of how to launch the training script on Beaker. -Run this with: - - python src/examples/train_with_mixture_launch.py run_name [OVERRIDES...] -""" - -import sys - -from beaker import Beaker - -from olmo_core.launch.beaker import BeakerEnvSecret, BeakerLaunchConfig -from olmo_core.utils import generate_uuid, prepare_cli_environment - - -def build_config(run_name: str) -> BeakerLaunchConfig: - beaker_user = (Beaker.from_env().account.whoami().name).upper() - return BeakerLaunchConfig( - name=f"olmo-core-test-{generate_uuid()[:8]}", - budget="ai2/oe-training", - cmd=["src/examples/train_with_mixture.py", run_name], - task_name="train", - workspace="ai2/OLMo-core", - description="Testing OLMo-core launch utilities", - clusters=["ai2/allennlp-cirrascale"], - env_secrets=[ - BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"), - BeakerEnvSecret(name="WANDB_API_KEY", secret=f"{beaker_user}_WANDB_API_KEY"), - BeakerEnvSecret(name="COMET_API_KEY", secret=f"{beaker_user}_COMET_API_KEY"), - BeakerEnvSecret(name="AWS_CONFIG", secret=f"{beaker_user}_AWS_CONFIG"), - BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"), - BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"), - BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"), - ], - setup_steps=[ - # Clone repo. - 'git clone "$REPO_URL" .', - 'git checkout "$GIT_REF"', - "git submodule update --init --recursive", - # Setup python environment. - "conda shell.bash activate base", - "pip install -e '.[all]'", - "pip freeze", - # Move AWS credentials from env to relevant files - "mkdir -p ~/.aws", - "printenv AWS_CONFIG > ~/.aws/config", - "printenv AWS_CREDENTIALS > ~/.aws/credentials", - ], - num_nodes=1, - num_gpus=4, - shared_filesystem=True, - nfs=False, - allow_dirty=True, - ) - - -if __name__ == "__main__": - if len(sys.argv) < 2: - print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]") - sys.exit(1) - - run_name = sys.argv[1] - - prepare_cli_environment() - - build_config(run_name).launch(follow=True) From 9c49f2573614cf0ce812ade00f572d67ea5d8dc2 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:32:01 -0700 Subject: [PATCH 44/57] A bit more cleanup --- src/olmo_core/data/source_mixture.py | 25 ++++++++++++------------- src/test/data/data_loader_test.py | 3 --- src/test/data/fixtures.py | 7 +------ 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index 426158b5..e962bae9 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -7,9 +7,8 @@ from typing import Dict, List, Optional, Tuple from rich.console import Console +from rich.progress import Progress from rich.table import Table -import tabulate -from tqdm import tqdm from olmo_core.aliases import PathOrStr from olmo_core.config import Config @@ -65,7 +64,7 @@ def for_table(self, max_tokens: int) -> Dict: return { "source_name": self.config.source_name, "source_population": f"{self.population:.2e}", - "num_sampled": f"{self.num_selected:.2e}", + "num_selected": f"{self.num_selected:.2e}", "target_ratio": str(self.config.target_ratio), "max_repetion_ratio": str(self.config.max_repetition_ratio), "max_source_fraction": str(self.config.max_source_fraction), @@ -245,16 +244,16 @@ def _count_tokens_for_paths(self, paths: List[PathOrStr], source: Optional[str]) for path in paths: futures.append(executor.submit(self._count_tokens_for_file, path)) - return sum( - [ - future.result() - for future in tqdm( - as_completed(futures), - total=len(futures), - desc=f"Counting tokens {'for ' + source if source else ''}", - ) - ] - ) + with Progress() as progress: + results = [] + task = progress.add_task( + f"Counting available tokens for source: {source}", total=len(futures) + ) + for future in as_completed(futures): + progress.update(task, advance=1) + results.append(future.result()) + + return sum(results) def _count_tokens_for_file(self, path: PathOrStr) -> int: return self._bytes_to_tokens(get_file_size(path), self.dtype) diff --git a/src/test/data/data_loader_test.py b/src/test/data/data_loader_test.py index f552f86e..3858b605 100644 --- a/src/test/data/data_loader_test.py +++ b/src/test/data/data_loader_test.py @@ -1,4 +1,3 @@ -from itertools import chain from pathlib import Path from typing import List @@ -16,8 +15,6 @@ VSLNaturalCurriculum, ) -from .fixtures import get_fsl_mixture - @pytest.mark.parametrize( "num_tokens, sequence_length, world_size, num_workers, num_threads, batch_size", diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py index 941445c7..a3f99c12 100644 --- a/src/test/data/fixtures.py +++ b/src/test/data/fixtures.py @@ -3,12 +3,7 @@ import numpy as np -from olmo_core.data import ( - NumpyDatasetBase, - NumpyFSLDataset, - NumpyDatasetConfig, - TokenizerConfig, -) +from olmo_core.data import NumpyDatasetBase, NumpyDatasetConfig, TokenizerConfig from olmo_core.data.source_mixture import ( SourceMixtureConfig, SourceMixtureDatasetConfig, From 89504bca26415a2013ed8aeb668c20e3c92a316e Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:33:07 -0700 Subject: [PATCH 45/57] Outdated changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38e4872c..5d98d782 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets. - Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets. - Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`. -- Added example launch script for training a model using a `NumpyFSLDatasetMixture`. ### Changed - Moved some types into `olmo_core.data.types` to avoid some circular dependencies. From 3aa5c3512ff0b68018d4d2bc84342564c009eb13 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:33:54 -0700 Subject: [PATCH 46/57] Unused deps --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8f24e277..d048a2a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,6 @@ dependencies = [ "omegaconf", "safetensors", "importlib_resources", - "tabulate", - "tqdm", ] [project.urls] From 8f729dddf3ff2ea4abcf8278b4e5840d0526302c Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:34:20 -0700 Subject: [PATCH 47/57] One more dep --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d048a2a5..5af9c589 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ dev = [ "sphinx-copybutton==0.5.2", "sphinx-autobuild==2021.3.14", "sphinx-autodoc-typehints==1.23.3", - "s3fs", ] beaker = [ "beaker-py", From a8481954c30e36d7dac0704a52b57eb7d22ff0b6 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:38:59 -0700 Subject: [PATCH 48/57] uncomment test assertions --- src/olmo_core/data/utils.py | 2 +- src/test/data/numpy_dataset_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index cc74deef..2d1819da 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -423,7 +423,7 @@ def segment_documents_into_instances( total_og_docs = 0 indices: List[int] = [] for start_idx, end_idx in iter_document_indices(path, eos_token_id=eos_token_id, dtype=dtype): - if max_instances and len(indices) // 2 >= max_instances: + if max_instances is not None and len(indices) // 2 >= max_instances: break total_og_docs += 1 length = end_idx - start_idx diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 8c26f2fb..0ab61104 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -125,7 +125,7 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): 15795, 52202, ] # stable because we pass a seed - # assert ds.num_tokens == 10000 + assert ds.num_tokens == 10000 assert len(ds) == 2500 @@ -184,7 +184,7 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): 15795, 52202, ] # stable because we pass a seed - # assert ds.num_tokens == 10000 + assert ds.num_tokens == 10000 assert len(ds) == 2500 From 5322bf1a889037b0590b88ab8f53bc3a0b926943 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:41:03 -0700 Subject: [PATCH 49/57] Drop todo --- src/test/data/source_mixture_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/data/source_mixture_test.py b/src/test/data/source_mixture_test.py index 87e2b4a0..264ff52f 100644 --- a/src/test/data/source_mixture_test.py +++ b/src/test/data/source_mixture_test.py @@ -271,7 +271,6 @@ def test_dataset_mixture_build_insufficient_source_max_fraction(tmp_path: Path): config.build() -# TODO: Handle duplicate paths in source mixture def test_dataset_mixture_build_duplicate_paths(tmp_path: Path): sources = { "1": mk_mmaps(tmp_path=tmp_path, prefix="source1", num_files=1, size=500_000), From 5c226651f8bfcc03fa158b0ac394e140b3bd4d0d Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Mon, 28 Oct 2024 11:45:06 -0700 Subject: [PATCH 50/57] 0 is an invalid token --- src/test/data/numpy_dataset_test.py | 16 ++++++++-------- src/test/utils.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 0ab61104..5d3a7d1e 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -120,10 +120,10 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): expected ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update expected fingerprint?" assert ds[0]["input_ids"].tolist() == [ - 56422, - 24545, - 15795, - 52202, + 56423, + 24546, + 15796, + 52203, ] # stable because we pass a seed assert ds.num_tokens == 10000 assert len(ds) == 2500 @@ -179,10 +179,10 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): expected ), f"Fingerprint mismatch, expected {expected}, got {ds.fingerprint[-6:]}...Do you need to update expected fingerprint?" assert ds[0]["input_ids"].tolist() == [ - 56422, - 24545, - 15795, - 52202, + 56423, + 24546, + 15796, + 52203, ] # stable because we pass a seed assert ds.num_tokens == 10000 assert len(ds) == 2500 diff --git a/src/test/utils.py b/src/test/utils.py index d09e6e57..14a622aa 100644 --- a/src/test/utils.py +++ b/src/test/utils.py @@ -113,7 +113,7 @@ def mk_mmaps( for i in range(num_files): filepath = f"{tmp_path}/{prefix}_{i}.npy" np.random.seed(seed) - data = np.random.randint(0, np.iinfo(dtype).max, size=size, dtype=dtype) + data = np.random.randint(1, np.iinfo(dtype).max, size=size, dtype=dtype) data = np.append( np.insert(data, np.arange(seq_length + 1, len(data), seq_length), eos), eos ) From 87e9168c3d6f982c63f0b311e54f53d92bed5899 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Tue, 29 Oct 2024 10:48:05 -0700 Subject: [PATCH 51/57] More feedback --- src/olmo_core/data/numpy_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index f2764818..1e04659f 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -508,7 +508,7 @@ def _get_file_size_and_length( class NumpyFSLDatasetMixture(NumpyFSLDataset): """ - A version of :class:`NumpyFSLDataset` built from a mixture of sources and their expected token ratios relative to each other. A path_offset_index is used to determine the number of instances to retain from a path when constructing the local indices. + A version of :class:`NumpyFSLDataset` built from a mixture of sources and their expected token ratios relative to each other. A ``path_offset_index`` is used to determine the number of instances to retain from a path when constructing the local indices. """ def __init__( @@ -1652,7 +1652,7 @@ def build(self) -> NumpyDatasetBase: paths.extend(matches) elif self.paths: paths = self.paths - elif self.source_mixture_config: + elif self.source_mixture_config and self.name == NumpyDatasetType.fsl: log.info("Building dataset from source mixture...") else: assert self.mix is not None From fe50a32cf07a4ec9a7d64ffbc0e87cdf6a7ef31e Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Tue, 29 Oct 2024 15:45:12 -0700 Subject: [PATCH 52/57] Randomly sample instances when segmenting --- src/olmo_core/data/numpy_dataset.py | 5 ++++- src/olmo_core/data/source_mixture.py | 3 ++- src/olmo_core/data/utils.py | 13 ++++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 1e04659f..2c1cd65f 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -515,6 +515,7 @@ def __init__( self, *paths: PathOrStr, path_offset_index: Dict[Tuple[str, int], int], + seed: int, sequence_length: int, pad_token_id: int, eos_token_id: int, @@ -565,6 +566,7 @@ def __init__( self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None self._path_offset_index = path_offset_index self._bust_index_cache = bust_index_cache + self._seed = seed def prepare(self): if self.fs_local_rank == 0: @@ -612,7 +614,7 @@ def _write_document_indices(self): eos_token_id=self.eos_token_id, dtype=self.dtype, indices_dtype=self.dtype, - max_instances=max_instances, + sample=(max_instances, self._seed), ) futures.append(future) @@ -1694,6 +1696,7 @@ def build(self) -> NumpyDatasetBase: mixture = self.source_mixture_config.build() return NumpyFSLDatasetMixture( *mixture.to_paths(), + seed=mixture.seed, sequence_length=self.sequence_length, max_target_sequence_length=self.max_target_sequence_length, pad_token_id=self.tokenizer.pad_token_id, diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index e962bae9..b19294bb 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -91,6 +91,7 @@ class SourceMixtureDataset: A dataset consisting of a fractionalized mixture of data sources. """ + seed: int sources: List[SourceMixtureOutcome] def to_index(self) -> Dict[Tuple[str, int], int]: @@ -196,7 +197,7 @@ def build(self) -> SourceMixtureDataset: for item in outcome.path_tokens: log.info(f"Selected {item.tokens} tokens from {outcome.name} at {item.path}") - return SourceMixtureDataset(completed) + return SourceMixtureDataset(seed=self.seed, sources=completed) def get_paths_and_tokens_for_source( self, source_config: SourceMixtureConfig, token_details: SourceTokenDetails diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index 2d1819da..8f025d94 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -412,25 +412,32 @@ def segment_documents_into_instances( indices_dtype: Union[ Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64] ] = np.uint32, - max_instances: Optional[int] = None, + sample: Optional[Tuple[int, int]] = None, ) -> Tuple[int, int]: """ Segment documents into instances of at most ``sequence_length`` tokens. Saving the indices of the instances to ``target``. + Sample a subset of the instances if ``sample`` is provided as a tuple of ``(max_instances, seed)``. + Returns the number of original documents and the number of resulting instances documents. """ total_og_docs = 0 indices: List[int] = [] for start_idx, end_idx in iter_document_indices(path, eos_token_id=eos_token_id, dtype=dtype): - if max_instances is not None and len(indices) // 2 >= max_instances: - break total_og_docs += 1 length = end_idx - start_idx indices.append(start_idx) indices.append(start_idx + min(length, max_sequence_length)) start_idx += length + if sample is not None: + max_instances, seed = sample + rng = get_rng(seed) + indices = ( + rng.choice(np.array(indices).reshape(-1, 2).tolist(), size=max_instances).flatten() + ).tolist() + with memmap_to_write(target, dtype=indices_dtype, shape=(len(indices),)) as indices_mmap: indices_mmap[:] = indices From cffcba38b7e965b09f68de9774226cb0c0c09a86 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 30 Oct 2024 10:41:49 -0700 Subject: [PATCH 53/57] Memray + limit marker --- src/olmo_core/data/utils.py | 22 +++++++++++----------- src/test/data/utils_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index 8f025d94..dd44d3af 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -423,22 +423,22 @@ def segment_documents_into_instances( Returns the number of original documents and the number of resulting instances documents. """ total_og_docs = 0 - indices: List[int] = [] - for start_idx, end_idx in iter_document_indices(path, eos_token_id=eos_token_id, dtype=dtype): - total_og_docs += 1 - length = end_idx - start_idx - indices.append(start_idx) - indices.append(start_idx + min(length, max_sequence_length)) - start_idx += length + idx_gen = ( + idx + for start_idx, end_idx in iter_document_indices( + path, eos_token_id=eos_token_id, dtype=dtype + ) + for idx in (start_idx, start_idx + min(end_idx - start_idx, max_sequence_length)) + ) + indices = np.fromiter(idx_gen, dtype=indices_dtype) + total_og_docs = len(indices) // 2 if sample is not None: max_instances, seed = sample rng = get_rng(seed) - indices = ( - rng.choice(np.array(indices).reshape(-1, 2).tolist(), size=max_instances).flatten() - ).tolist() + indices = rng.choice(indices.reshape(-1, 2), size=max_instances).reshape(-1) - with memmap_to_write(target, dtype=indices_dtype, shape=(len(indices),)) as indices_mmap: + with memmap_to_write(target, dtype=indices_dtype, shape=(indices.size,)) as indices_mmap: indices_mmap[:] = indices return total_og_docs, len(indices) // 2 diff --git a/src/test/data/utils_test.py b/src/test/data/utils_test.py index 4d678c09..1d19b065 100644 --- a/src/test/data/utils_test.py +++ b/src/test/data/utils_test.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from olmo_core.data.utils import ( @@ -9,9 +10,40 @@ iter_document_indices, melt_batch, write_document_indices, + segment_documents_into_instances, ) +@pytest.mark.limit_memory("11 KB") +def test_segment_documents_into_instances(tmp_path): + data = [1, 2, 3, 4, 0, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] * 10 + data_path = tmp_path / "data.npy" + max_sequence_length = 4 + mmap = np.memmap(data_path, mode="w+", dtype=np.uint16, shape=(len(data),)) + indices_path = tmp_path / "indices.npy" + mmap[:] = data + mmap.flush() + + eos = 0 + dtype = np.uint16 + sample = (2, 42) + + results = [] + for _ in range(10): + results.append( + segment_documents_into_instances( + path=data_path, + target=indices_path, + max_sequence_length=max_sequence_length, + eos_token_id=eos, + dtype=dtype, + sample=sample, + ) + ) + + assert all([r[1] == 2 for r in results]) + + def test_iter_document_indices(tmp_path): data = [1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] data_path = tmp_path / "data.npy" From 293be028d45e441e260f434e48ad72fa35dbaed3 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 30 Oct 2024 10:42:53 -0700 Subject: [PATCH 54/57] Add dep --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 5af9c589..9356a053 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dev = [ "black>=23.1,<24.0", "isort>=5.12,<5.13", "pytest", + "pytest-memray", "pytest-sphinx", "pytest-xdist", "twine>=1.11.0", From e45d2c3baff6b70bd3746d89bb289d47b7a27bf4 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 30 Oct 2024 10:43:19 -0700 Subject: [PATCH 55/57] Lint --- src/test/data/utils_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/data/utils_test.py b/src/test/data/utils_test.py index 1d19b065..b5140e42 100644 --- a/src/test/data/utils_test.py +++ b/src/test/data/utils_test.py @@ -9,8 +9,8 @@ iter_batched, iter_document_indices, melt_batch, - write_document_indices, segment_documents_into_instances, + write_document_indices, ) From ffe7660efecfbf2c36c3de57e7be8f6d1e860c38 Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 30 Oct 2024 11:10:34 -0700 Subject: [PATCH 56/57] Bigger array is more informative --- src/test/data/utils_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/data/utils_test.py b/src/test/data/utils_test.py index b5140e42..caef70ba 100644 --- a/src/test/data/utils_test.py +++ b/src/test/data/utils_test.py @@ -14,9 +14,9 @@ ) -@pytest.mark.limit_memory("11 KB") +@pytest.mark.limit_memory("245 KB") def test_segment_documents_into_instances(tmp_path): - data = [1, 2, 3, 4, 0, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] * 10 + data = [1, 2, 3, 4, 0, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 0] * 1000 data_path = tmp_path / "data.npy" max_sequence_length = 4 mmap = np.memmap(data_path, mode="w+", dtype=np.uint16, shape=(len(data),)) From 1fdb995302cc1d870b705e429d3709dbc7e342cb Mon Sep 17 00:00:00 2001 From: Tyler Murray Date: Wed, 30 Oct 2024 14:47:06 -0700 Subject: [PATCH 57/57] Feedback --- CHANGELOG.md | 2 +- src/olmo_core/data/numpy_dataset.py | 26 ++++-------- src/olmo_core/data/source_mixture.py | 63 +++++++++++++++++++++++++++- src/test/data/fixtures.py | 1 - src/test/data/numpy_dataset_test.py | 2 - 5 files changed, 72 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc909086..24a9ea71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,10 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets. - Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets. - Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`. +- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals). ### Changed - Moved some types into `olmo_core.data.types` to avoid some circular dependencies. -- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals). ### Removed diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py index 2c1cd65f..6862e7e6 100644 --- a/src/olmo_core/data/numpy_dataset.py +++ b/src/olmo_core/data/numpy_dataset.py @@ -135,7 +135,7 @@ def file_sizes(self) -> Tuple[int, ...]: The size, in bytes, of each numpy array. """ if self._array_file_sizes is None: - self._array_file_sizes = tuple(self.map(lambda item: get_file_size(item[0]))) + self._array_file_sizes = tuple(self.map(lambda path, _: get_file_size(path))) return self._array_file_sizes @property @@ -242,7 +242,7 @@ def _warmup_clients(self): def map( self, - func: Callable[[Tuple[PathOrStr, int]], T], + func: Callable[[PathOrStr, int], T], *, max_workers: Optional[int] = None, method: Literal["threads", "processes"] = "threads", @@ -251,7 +251,7 @@ def map( """ Call a function on each path in the dataset, returning a list of the results, in order. - :param func: The function to map to the paths. + :param func: The function to map to the paths and their indices. :param max_workers: The number of workers threads/processes. Set to 0 to execute synchronously in the main thread/process. :param method: Whether to use multi-threading or multi-processing. @@ -261,7 +261,7 @@ def map( paths = _paths or self.paths if max_workers == 0: - return [func((path, idx)) for idx, path in enumerate(paths)] + return [func(path, idx) for idx, path in enumerate(paths)] executor_class: Union[ Type[concurrent.futures.ThreadPoolExecutor], @@ -276,7 +276,7 @@ def map( raise ValueError(method) with executor_class(max_workers=max_workers) as executor: - futures = [executor.submit(func, (path, idx)) for idx, path in enumerate(paths)] + futures = [executor.submit(func, path, idx) for idx, path in enumerate(paths)] return [future.result() for future in futures] @@ -484,9 +484,8 @@ def _read_chunk_from_array(self, path: PathOrStr, index: int) -> torch.Tensor: ) def _get_file_size_and_length( - self, item: Tuple[PathOrStr, int], dtype: Optional[NumpyUIntTypes] = None + self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None ) -> Tuple[int, int]: - path, _ = item dtype = dtype or self.dtype item_size = dtype(0).itemsize file_size = get_file_size(path) @@ -525,7 +524,6 @@ def __init__( include_instance_metadata: Optional[bool] = None, generate_doc_lengths: bool = False, max_target_sequence_length: Optional[int] = None, - bust_index_cache: bool = False, ): if max_target_sequence_length is not None and ( max_target_sequence_length < sequence_length @@ -565,7 +563,6 @@ def __init__( self._lengths_dtype: Optional[NumpyUIntTypes] = None self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None self._path_offset_index = path_offset_index - self._bust_index_cache = bust_index_cache self._seed = seed def prepare(self): @@ -629,11 +626,11 @@ def _write_document_indices(self): ) def _get_file_size_and_length( - self, item: Tuple[PathOrStr, int], dtype: Optional[NumpyUIntTypes] = None + self, path: PathOrStr, idx: int, dtype: Optional[NumpyUIntTypes] = None ) -> Tuple[int, int]: dtype = dtype or self.dtype item_size = dtype(0).itemsize - file_size = self._get_size_from_offset_index(item) + file_size = self._get_size_from_offset_index((path, idx)) if ( self.max_target_sequence_length is None or self.max_target_sequence_length == self.sequence_length @@ -692,7 +689,7 @@ def offsets(self) -> Tuple[Tuple[int, int], ...]: if self._array_instance_offsets is None: item_size = self.indices_dtype(0).itemsize num_instances_per_path = self.map( - lambda item: get_file_size(self._get_instance_indices_path(item[0])) + lambda path, _: get_file_size(self._get_instance_indices_path(path)) // (item_size * 2) ) array_instance_offsets = [] @@ -1485,10 +1482,6 @@ class NumpyDatasetConfig(Config): """ The type of dataset. """ - bust_index_cache: bool = False - """ - Whether or not to bust the index cache. - """ source_mixture_config: Optional[SourceMixtureDatasetConfig] = None """ The source mixture dataset config. @@ -1707,7 +1700,6 @@ def build(self) -> NumpyDatasetBase: include_instance_metadata=self.include_instance_metadata, generate_doc_lengths=self.generate_doc_lengths, path_offset_index=mixture.to_index(), - bust_index_cache=self.bust_index_cache, ) else: dataset = NumpyFSLDataset( diff --git a/src/olmo_core/data/source_mixture.py b/src/olmo_core/data/source_mixture.py index b19294bb..c5858070 100644 --- a/src/olmo_core/data/source_mixture.py +++ b/src/olmo_core/data/source_mixture.py @@ -27,12 +27,31 @@ @dataclass class SourceMixtureConfig(Config): + """ + A configuration class for building a source mixture. + """ + source_name: str + """ + The name of the source. + """ target_ratio: float + """ + The target ratio of the source in the mixture. + """ paths: List[PathOrStr] - # 1.0 will result in a maximum of 1 repitition of the source data per epoch + """ + A list of paths to the source data. + """ max_repetition_ratio: float = 1.0 + """ + The maximum ratio of repetitions of the source data to include in the mixture. + This can be used to upsample the source data by setting the repetition ratio > 1. + """ max_source_fraction: float = 1.0 + """ + The maximum ratio of the source data to include in the mixture. + """ def validate(self): if self.target_ratio: @@ -43,6 +62,9 @@ def validate(self): if self.max_source_fraction < self.target_ratio: raise OLMoConfigurationError("max_source_fraction must be >= target_ratio") + if self.max_repetition_ratio < 1: + raise OLMoConfigurationError("max_repetition_ratio must be >= 1") + if not self.paths: raise OLMoConfigurationError("paths must not be empty") @@ -57,8 +79,17 @@ class SourceTokenDetails: """ config: SourceMixtureConfig + """ + The configuration object associated with the source. + """ population: int + """ + The total number of tokens available for the source. + """ num_selected: int + """ + The number of tokens to select for the source. + """ def for_table(self, max_tokens: int) -> Dict: return { @@ -82,7 +113,13 @@ class SourcePathTokens: @dataclass class SourceMixtureOutcome: name: str + """ + The name of the source. + """ path_tokens: List[SourcePathTokens] + """ + A list of paths and the associated token counts. + """ @dataclass @@ -92,7 +129,13 @@ class SourceMixtureDataset: """ seed: int + """ + The seed used to generate the dataset. + """ sources: List[SourceMixtureOutcome] + """ + A list of sources and the associated paths and token counts. + """ def to_index(self) -> Dict[Tuple[str, int], int]: """ @@ -122,11 +165,29 @@ class SourceMixtureDatasetConfig(Config): """ max_tokens: int + """ + The maximum number of tokens to include in the dataset. + """ source_configs: List[SourceMixtureConfig] + """ + A list of source configurations. + """ sequence_length: int + """ + The instance sequence length of the dataset. + """ dtype: NumpyDatasetDType + """ + The data type of the dataset. + """ processes: int = 1 + """ + The number of processes to use for counting tokens in parallel. + """ seed: int = 42 + """ + The seed used to generate the dataset. + """ def validate(self): if self.max_tokens <= 0: diff --git a/src/test/data/fixtures.py b/src/test/data/fixtures.py index a3f99c12..4fcaa84d 100644 --- a/src/test/data/fixtures.py +++ b/src/test/data/fixtures.py @@ -59,7 +59,6 @@ def get_fsl_mixture( source_mixture_config=mixture_config, sequence_length=sequence_length, tokenizer=tokenizer, - bust_index_cache=True, include_instance_metadata=False, ).build() ds.prepare() diff --git a/src/test/data/numpy_dataset_test.py b/src/test/data/numpy_dataset_test.py index 5d3a7d1e..e8c49952 100644 --- a/src/test/data/numpy_dataset_test.py +++ b/src/test/data/numpy_dataset_test.py @@ -110,7 +110,6 @@ def test_numpy_fsl_mixture_dataset(tmp_path: Path): source_mixture_config=mixture_config, sequence_length=sequence_length, tokenizer=tokenizer, - bust_index_cache=True, include_instance_metadata=False, ).build() ds.prepare() @@ -169,7 +168,6 @@ def test_numpy_fsl_mixture_dataset_with_repetition(tmp_path: Path): source_mixture_config=mixture_config, sequence_length=sequence_length, tokenizer=tokenizer, - bust_index_cache=True, include_instance_metadata=False, ).build() ds.prepare()