diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb67d4e93..277e654f0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- Added `instance_filter_config` field to `NumpyDatasetConfig`.
+
 ## [v1.8.0](https://github.com/allenai/OLMo-core/releases/tag/v1.8.0) - 2025-01-29
 
 ### Added
@@ -24,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added support for logging rich Table objects as text in source mixture datasets.
 - Added `unshard_strategy` parameter to `unshard_checkpoint()` function in `olmo_core.distributed.checkpoint`.
 - Added function `load_keys()` to `olmo_core.distributed.checkpoint`.
+- Added `RunDuration` in `model_ladder` to configure training durations in terms of Chinchilla multipliers.
 
 ### Changed
 
diff --git a/src/olmo_core/data/__init__.py b/src/olmo_core/data/__init__.py
index b710100e1..22b7893d1 100644
--- a/src/olmo_core/data/__init__.py
+++ b/src/olmo_core/data/__init__.py
@@ -22,6 +22,7 @@
 )
 from .mixes import DataMix, DataMixBase
 from .numpy_dataset import (
+    InstanceFilterConfig,
     NumpyDatasetBase,
     NumpyDatasetConfig,
     NumpyFSLDataset,
@@ -50,6 +51,7 @@
     "VSLGrowLinearCurriculum",
     "NumpyDatasetConfig",
     "NumpyDatasetType",
+    "InstanceFilterConfig",
     "VSLCurriculumType",
     "VSLCurriculumConfig",
     "NumpyDatasetDType",
diff --git a/src/olmo_core/data/numpy_dataset.py b/src/olmo_core/data/numpy_dataset.py
index 933861552..4ef0b917a 100644
--- a/src/olmo_core/data/numpy_dataset.py
+++ b/src/olmo_core/data/numpy_dataset.py
@@ -43,6 +43,7 @@
     bucket_documents,
     chunk_array,
     divide_into_buckets,
+    find_periodic_sequences,
     get_doc_lengths_from_indices,
     get_document_lengths,
     get_rng,
@@ -305,6 +306,25 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         """
         raise NotImplementedError
 
+    def _validate_instance(
+        self, input_ids: torch.Tensor, instance_filter_config: InstanceFilterConfig
+    ) -> bool:
+        for m in find_periodic_sequences(
+            input_ids.numpy(),
+            max_period=instance_filter_config.repetition_max_period,
+            min_period=instance_filter_config.repetition_min_period,
+        ):
+            if m.times >= instance_filter_config.repetition_max_count:
+                return False
+        return True
+
+
+@dataclass
+class InstanceFilterConfig(Config):
+    repetition_max_period: int = 13
+    repetition_min_period: int = 1
+    repetition_max_count: int = 32
+
 
 class NumpyFSLDataset(NumpyDatasetBase, Dataset[Dict[str, Any]]):
     """
@@ -352,6 +372,7 @@ def __init__(
         include_instance_metadata: Optional[bool] = None,
         generate_doc_lengths: bool = False,
         max_target_sequence_length: Optional[int] = None,
+        instance_filter_config: Optional[InstanceFilterConfig] = None,
     ):
         if max_target_sequence_length is not None and (
             max_target_sequence_length < sequence_length
@@ -386,6 +407,7 @@ def __init__(
         self._num_instances: Optional[int] = None
         self._include_instance_metadata = include_instance_metadata
         self._generate_doc_lengths = generate_doc_lengths
+        self.instance_filter_config = instance_filter_config
 
     @property
     def num_tokens(self) -> int:
@@ -450,6 +472,9 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         input_ids = self._read_chunk_from_array(self.paths[array_index], array_local_index)
         out: Dict[str, Any] = {"input_ids": input_ids}
 
+        if self.instance_filter_config is not None:
+            out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)
+
         if self._include_instance_metadata:
             metadata = self._metadata[array_index]
             out["metadata"] = deepcopy(metadata)
@@ -525,6 +550,7 @@ def __init__(
         include_instance_metadata: Optional[bool] = None,
         generate_doc_lengths: bool = False,
         max_target_sequence_length: Optional[int] = None,
+        instance_filter_config: Optional[InstanceFilterConfig] = None,
     ):
         if max_target_sequence_length is not None and (
             max_target_sequence_length < sequence_length
@@ -565,6 +591,7 @@ def __init__(
         self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None
         self._path_offset_index = path_offset_index
         self._seed = seed
+        self.instance_filter_config = instance_filter_config
 
     @property
     def indices_dtype(
@@ -692,6 +719,7 @@ def __init__(
         dtype: NumpyUIntTypes = np.uint16,
         metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
         include_instance_metadata: Optional[bool] = None,
+        instance_filter_config: Optional[InstanceFilterConfig] = None,
     ):
         super().__init__(
             *paths,
@@ -702,6 +730,7 @@ def __init__(
             dtype=dtype,
             metadata=metadata,
             include_instance_metadata=include_instance_metadata,
+            instance_filter_config=instance_filter_config,
         )
 
         self._array_instance_offsets: Optional[Tuple[Tuple[int, int], ...]] = None
@@ -1122,6 +1151,7 @@ def __init__(
         dtype: NumpyUIntTypes = np.uint16,
         metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        include_instance_metadata: Optional[bool] = None,
+        instance_filter_config: Optional[InstanceFilterConfig] = None,
     ):
         if math.log(max_sequence_length, 2) % 1 != 0:
             raise OLMoConfigurationError("'max_sequence_length' must be a power of 2")
@@ -1161,6 +1191,7 @@ def __init__(
         self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None
         self._lengths_dtype: Optional[NumpyUIntTypes] = None
         self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None
+        self.instance_filter_config = instance_filter_config
 
     @property
     def fingerprint_version(self) -> str:
@@ -1251,6 +1282,9 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
         input_ids = self._read_chunk_from_array(self.paths[array_index], array_local_index)
         out: Dict[str, Any] = {"input_ids": input_ids}
 
+        if self.instance_filter_config is not None:
+            out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)
+
         if self._include_instance_metadata:
             metadata = self._metadata[array_index]
             out["metadata"] = deepcopy(metadata)
@@ -1570,6 +1604,7 @@ class NumpyDatasetConfig(Config):
     You can save a lot of time and disk space by setting this to a common directory across
     all of you runs.
""" + instance_filter_config: Optional[InstanceFilterConfig] = None def validate(self): if self.name in (NumpyDatasetType.fsl, NumpyDatasetType.padded_fsl): @@ -1724,6 +1759,7 @@ def build(self) -> NumpyDatasetBase: include_instance_metadata=self.include_instance_metadata, generate_doc_lengths=self.generate_doc_lengths, path_offset_index=mixture.to_index(), + instance_filter_config=self.instance_filter_config, ) else: dataset = NumpyFSLDataset( @@ -1737,6 +1773,7 @@ def build(self) -> NumpyDatasetBase: metadata=metadata, include_instance_metadata=self.include_instance_metadata, generate_doc_lengths=self.generate_doc_lengths, + instance_filter_config=self.instance_filter_config, ) elif self.name == NumpyDatasetType.padded_fsl: if self.sequence_length is None: @@ -1776,6 +1813,7 @@ def build(self) -> NumpyDatasetBase: dtype=self.get_dtype(), metadata=metadata, include_instance_metadata=self.include_instance_metadata, + instance_filter_config=self.instance_filter_config, ) elif self.name == NumpyDatasetType.vsl: if self.max_sequence_length is None: @@ -1801,6 +1839,7 @@ def build(self) -> NumpyDatasetBase: dtype=self.get_dtype(), metadata=metadata, include_instance_metadata=self.include_instance_metadata, + instance_filter_config=self.instance_filter_config, ) else: raise NotImplementedError(self.name) diff --git a/src/olmo_core/data/utils.py b/src/olmo_core/data/utils.py index dd44d3afd..4a69908a7 100644 --- a/src/olmo_core/data/utils.py +++ b/src/olmo_core/data/utils.py @@ -10,6 +10,7 @@ Generator, Iterable, List, + NamedTuple, Optional, Sequence, Tuple, @@ -468,3 +469,118 @@ def get_labels(batch: Dict[str, Any], label_ignore_index: int = -100) -> torch.T if instance_mask is not None: labels.masked_fill_(~instance_mask.unsqueeze(-1), value=label_ignore_index) return labels[..., 1:].contiguous() + + +def find_end_first_consecutive_true(arr: np.ndarray) -> int: + """Function to find the end position of the first consecutive sequence of True in an array.""" + if not arr[0]: + return 0 + + prog = np.cumsum(arr) + if prog[-1] == len(arr): + return len(arr) + + true_locs = np.where(prog[:-1:] == prog[1::])[0] + + return true_locs[0] + 1 + + +def find_start_last_consecutive_true(arr: np.ndarray) -> int: + """Function to find the start position of the last consecutive sequence of True in an array.""" + reverse = find_end_first_consecutive_true(arr[::-1]) + return len(arr) - reverse if reverse > 0 else -1 + + +def group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndarray]: + """Function to group consecutive values in an array.""" + return np.split(arr, np.where(np.diff(arr) != stepsize)[0] + 1) + + +class RepetitionTuple(NamedTuple): + """Tuple to store information about a periodic sequence.""" + + start: int + end: int + period: int + times: int + + +def find_periodic_sequences( + arr: np.ndarray, max_period: int, min_period: int = 1, mask_value: int = -1 +) -> Generator[RepetitionTuple, None, None]: + """Function to find periodic sequences in an array. + + This function sweeps through the array and checks for sequences of length + [min_period, max_period] that repeat at least 3 times. To do so, it + reshape the array into a matrix with `period` columns and checks if each + row is equal to the previous row. Blocks of repeating rows indicates repeating + sequences. + + Because there's no guarantee that the sequences start at the beginning of each + row, it can only detect sequences that repeat at least 3 times. 
+    To account for the fact that sequences may not start at the beginning of each
+    row (or end at the end of each row), we check the end of the previous row and
+    the start of the next row to determine the actual start and end positions of
+    the sequence.
+
+    Args:
+        arr (np.ndarray): The array to search for periodic sequences.
+        max_period (int): The maximum period to check for.
+        min_period (int, optional): The minimum period to check for. Defaults to 1.
+        mask_value (int, optional): The value to use to pad the array. Defaults to -1.
+    """
+    # make sure the mask_value is not in the array
+    if (arr == mask_value).sum() > 0:
+        raise ValueError("`mask_value` is in the array")
+
+    # note: since we can only detect sequences that repeat at least 3 times,
+    # there is no point in checking for periods greater than 1/3 of the length
+    max_period = min(max_period, len(arr) // 3)
+
+    for period in range(min_period, max_period + 1):
+        # pad the array so that it can be reshaped into a matrix matching the period
+        padded_arr = np.pad(arr, (0, period - (len(arr) % period)), constant_values=mask_value)
+        shaped_arr = padded_arr.reshape(-1, period)
+
+        # find rows that are equal to the previous row; these are the possibly-periodic sequences
+        is_equal_to_prev_row = shaped_arr == np.roll(shaped_arr, shift=1, axis=0)
+        rows_with_period, *_ = np.where(is_equal_to_prev_row.all(axis=1))
+
+        # no sequences found with this period
+        if len(rows_with_period) == 0:
+            continue
+
+        # this finds the start and end positions of the sequences with period `period`
+        where_true_consecutive = group_consecutive_values(rows_with_period)
+
+        for sequence in where_true_consecutive:
+            start_row = sequence[0]
+            end_row = sequence[-1]
+
+            # we check if any value at the end of the previous row is True, e.g.:
+            # [[False, False, True, True]
+            # [True, True, True, True]]
+            # (in the case above, start offset is 2). If so, we subtract that from the
+            # period to get the actual start offset.
+            start_offset = find_start_last_consecutive_true(is_equal_to_prev_row[start_row - 1])
+            start_offset = period - start_offset if start_offset > 0 else 0
+
+            # same idea as above, we want to compute the offset. The only difference is that
+            # `find_end_first_consecutive_true` already returns the offset, so we don't
+            # need to subtract from the period.
+            end_offset = find_end_first_consecutive_true(is_equal_to_prev_row[end_row + 1])
+
+            # because we are always comparing with preceding row in
+            # `is_equal_to_prev_row`, we need to subtract 1 from the row number
+            start_pos = (start_row - 1) * period - start_offset
+
+            # note that the end position is exclusive
+            end_pos = ((end_row + 1) * period) + end_offset
+
+            out = RepetitionTuple(
+                start=start_pos, end=end_pos, period=period, times=(end_pos - start_pos) // period
+            )
+            if out.times > 2:
+                # cannot accurately determine the period of a sequence that repeats
+                # less than 3 times with this algorithm
+                yield out
diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py
index 9c87a8889..29c0ef7f8 100644
--- a/src/olmo_core/internal/experiment.py
+++ b/src/olmo_core/internal/experiment.py
@@ -9,6 +9,7 @@
 from olmo_core.config import Config, StrEnum
 from olmo_core.data import (
     DataMix,
+    InstanceFilterConfig,
     NumpyDataLoaderConfig,
     NumpyDatasetConfig,
     NumpyDatasetType,
@@ -170,6 +171,11 @@ def build_common_components(
             name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False
         ),
         work_dir=get_work_dir(root_dir),
+        instance_filter_config=InstanceFilterConfig(
+            repetition_max_period=13,
+            repetition_min_period=1,
+            repetition_max_count=32,
+        ),
     )
 
     data_loader_config = NumpyDataLoaderConfig(
diff --git a/src/olmo_core/internal/model_ladder.py b/src/olmo_core/internal/model_ladder.py
index 96b697670..97acfbb03 100644
--- a/src/olmo_core/internal/model_ladder.py
+++ b/src/olmo_core/internal/model_ladder.py
@@ -9,7 +9,7 @@
 from olmo_core.data import NumpyDataLoaderConfig, NumpyDatasetConfig
 from olmo_core.distributed.utils import get_local_rank
 from olmo_core.launch.beaker import BeakerLaunchConfig
-from olmo_core.model_ladder import ModelLadder, ModelSize
+from olmo_core.model_ladder import ModelLadder, ModelSize, RunDuration
 from olmo_core.nn.transformer import TransformerConfig
 from olmo_core.optim import OptimConfig
 from olmo_core.train import (
@@ -97,6 +97,7 @@ def build_config(
     ladder: ModelLadder,
     script: str,
     size: ModelSize,
+    run_duration: RunDuration,
     cmd: SubCmd,
     cluster: str,
     overrides: List[str],
@@ -105,9 +106,9 @@ def build_config(
     root_dir = get_root_dir(cluster)
 
     launch = build_launch_config(
-        name=f"{ladder.name}-{size}",
+        name=f"{ladder.name}-{size}-{run_duration}",
         root_dir=root_dir,
-        cmd=[script, SubCmd.train, size, cluster, *overrides],
+        cmd=[script, SubCmd.train, size, run_duration, cluster, *overrides],
         cluster=cluster,
     ).merge(overrides, strict=False)
 
@@ -115,10 +116,12 @@ def build_config(
     gpu_type = get_gpu_type(cluster)
 
     model = ladder.get_model_config(size=size)
-    optim = ladder.get_optim_config(size=size)
+    optim = ladder.get_optim_config()
     dataset = ladder.get_dataset_config()
-    data_loader = ladder.get_data_loader_config(size=size)
-    trainer = ladder.get_trainer_config(size=size, gpu_type=gpu_type, dp_world_size=dp_world_size)
+    data_loader = ladder.get_data_loader_config()
+    trainer = ladder.get_trainer_config(
+        size=size, run_duration=run_duration, gpu_type=gpu_type, dp_world_size=dp_world_size
+    )
 
     return LadderRunConfig(
         launch=launch,
@@ -133,7 +136,7 @@ def build_config(
 
 def main(ladder_builder: Callable[[str], ModelLadder]):
     usage = f"""
-[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]SIZE CLUSTER[/] [i][OVERRIDES...][/]
+[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]SIZE RUN_DURATION CLUSTER[/] [i][OVERRIDES...][/]
 
 [b]Subcommands[/]
 [b magenta]launch:[/] Launch the script on Beaker with the [b magenta]train[/] subcommand.
@@ -142,16 +145,17 @@ def main(ladder_builder: Callable[[str], ModelLadder]):
 [b magenta]dry_run:[/] Pretty print the config to run and exit.
 
 [b]Examples[/]
-$ [i]python {sys.argv[0]} {SubCmd.launch} 1B ai2/pluto-cirrascale --launch.num_nodes=2[/]
+$ [i]python {sys.argv[0]} {SubCmd.launch} 1B 1xC ai2/pluto-cirrascale --launch.num_nodes=2[/]
 """.strip()
 
     try:
-        script, cmd, size, cluster, overrides = (
+        script, cmd, size, run_duration, cluster, overrides = (
             sys.argv[0],
             SubCmd(sys.argv[1]),
             ModelSize(sys.argv[2]),
-            sys.argv[3],
-            sys.argv[4:],
+            RunDuration(sys.argv[3]),
+            sys.argv[4],
+            sys.argv[5:],
         )
     except (IndexError, ValueError):
         import rich
@@ -163,10 +167,10 @@ def main(ladder_builder: Callable[[str], ModelLadder]):
 
     # Build ladder config.
     ladder = ladder_builder(get_root_dir(cluster))
-    ladder.merge(overrides, prefix="ladder")
+    ladder = ladder.merge(overrides, prefix="ladder")
 
     # Build run config.
-    config = build_config(ladder, script, size, cmd, cluster, overrides)
+    config = build_config(ladder, script, size, run_duration, cmd, cluster, overrides)
     config.ladder.validate()
 
     # Run the cmd.
diff --git a/src/olmo_core/model_ladder.py b/src/olmo_core/model_ladder.py
index e767e5643..9f8a94a8e 100644
--- a/src/olmo_core/model_ladder.py
+++ b/src/olmo_core/model_ladder.py
@@ -6,6 +6,8 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass, field
 
+from olmo_core.data.numpy_dataset import InstanceFilterConfig
+
 from .config import Config, StrEnum
 from .data import (
     DataMix,
@@ -33,7 +35,7 @@
     WandBCallback,
 )
 
-__all__ = ["ModelSize", "ModelLadder"]
+__all__ = ["ModelSize", "ModelLadder", "RunDuration"]
 
 log = logging.getLogger(__name__)
 
@@ -89,6 +91,39 @@ def num_params(self) -> int:
         raise NotImplementedError(self)
 
 
+class RunDuration(StrEnum):
+    """
+    An enumeration of the standard training durations for the ladder, in terms of Chinchilla multipliers.
+    """
+
+    Cx0_5 = "0.5xC"
+    """
+    Multiplier of 0.5.
+    """
+
+    Cx1 = "1xC"
+    """
+    Multiplier of 1.
+    """
+    Cx2 = "2xC"
+    """
+    Multiplier of 2.
+    """
+    Cx5 = "5xC"
+    """
+    Multiplier of 5.
+    """
+
+    Cx10 = "10xC"
+    """
+    Multiplier of 10.
+    """
+
+    @property
+    def multiplier(self) -> float:
+        return float(self.split("xC")[0])
+
+
 @beta_feature
 @dataclass
 class ModelLadder(Config, metaclass=ABCMeta):
@@ -159,11 +194,18 @@ class ModelLadder(Config, metaclass=ABCMeta):
     The maximum data parallel world size that you intent to run with. This is used to set
     the batch size.
     """
 
-    def get_save_folder(self, size: ModelSize) -> str:
-        return str(join_path(self.save_folder, f"checkpoints/{self.name}-{size}"))
+    @property
+    def model_size(self) -> int:
+        """
+        The size of the model in terms of non-embedding parameters.
+        """
+        return self._model_size
+
+    def get_save_folder(self, size: ModelSize, run_duration: RunDuration) -> str:
+        return str(join_path(self.save_folder, f"checkpoints/{self.name}-{size}-{run_duration}"))
 
     @abstractmethod
-    def get_model_config(self, *, size: ModelSize) -> TransformerConfig:
+    def _get_model_config(self, *, size: ModelSize) -> TransformerConfig:
         """
         Get the model config for a given model size.
@@ -171,8 +213,18 @@ def get_model_config(self, *, size: ModelSize) -> TransformerConfig:
         """
         raise NotImplementedError
 
+    def get_model_config(self, *, size: ModelSize) -> TransformerConfig:
+        """
+        Get the model config for a given model size.
+
+        :param size: The target model size.
+ """ + model_config = self._get_model_config(size=size) + self._model_size = model_config.num_non_embedding_params + return model_config + @abstractmethod - def get_optim_config(self, *, size: ModelSize) -> OptimConfig: + def get_optim_config(self) -> OptimConfig: """ Get the optimizer config for a given model size. @@ -192,16 +244,21 @@ def get_dataset_config(self) -> NumpyDatasetConfig: mix_base_dir=self.mix_base_dir, sequence_length=self.sequence_length, work_dir=self.work_dir, + instance_filter_config=InstanceFilterConfig( + repetition_max_period=13, + repetition_min_period=1, + repetition_max_count=32, + ), ) - def get_data_loader_config(self, *, size: ModelSize) -> NumpyDataLoaderConfig: + def get_data_loader_config(self) -> NumpyDataLoaderConfig: """ Get the data loader config. :param size: The target model size. """ return NumpyDataLoaderConfig( - global_batch_size=self.get_global_batch_size(size=size), + global_batch_size=self.get_global_batch_size(), seed=self.data_seed, num_workers=4, ) @@ -217,7 +274,7 @@ def get_rank_microbatch_size(self, *, size: ModelSize, gpu_type: str) -> int: """ raise NotImplementedError - def get_global_batch_size(self, *, size: ModelSize) -> int: + def get_global_batch_size(self) -> int: """ Get the global batch size in tokens for a given model size. @@ -228,7 +285,7 @@ def get_global_batch_size(self, *, size: ModelSize) -> int: assert self.sequence_length in {2048, 4096, 8192} seq_len_divisor = self.sequence_length // 2048 - global_batch_size = 160 * (size.num_params / 108000000) ** (2 / 3) + global_batch_size = 160 * (self.model_size / 108000000) ** (2 / 3) global_batch_size /= seq_len_divisor global_batch_size /= self.max_dp_world_size global_batch_size = round(global_batch_size) @@ -236,18 +293,19 @@ def get_global_batch_size(self, *, size: ModelSize) -> int: return self.sequence_length * global_batch_size - def get_duration(self, size: ModelSize) -> Duration: + def get_duration(self, run_duration: RunDuration = RunDuration.Cx2) -> Duration: """ Get the duration to train for given the model size. Defaults to 2 x Chinchilla optimal. :param size: The target model size. 
""" - return Duration.tokens(2 * 20 * size.num_params) + return Duration.tokens(int(run_duration.multiplier * 20) * self.model_size) def get_trainer_config( self, *, size: ModelSize, + run_duration: RunDuration, gpu_type: str, dp_world_size: int, ) -> TrainerConfig: @@ -275,14 +333,14 @@ def get_trainer_config( rank_mbz_instances = rank_mbz // self.sequence_length - global_bz = self.get_global_batch_size(size=size) + global_bz = self.get_global_batch_size() if global_bz % self.sequence_length != 0: raise OLMoConfigurationError( f"global batch size ({rank_mbz:,d} tokens) must be divisible " f"by the sequence length ({self.sequence_length:,d})" ) - global_bz_instances = self.get_global_batch_size(size=size) // self.sequence_length + global_bz_instances = self.get_global_batch_size() // self.sequence_length if global_bz_instances % (rank_mbz_instances * dp_world_size) != 0: new_rank_mbz_instances = global_bz_instances // dp_world_size @@ -310,15 +368,20 @@ def get_trainer_config( return ( TrainerConfig( - save_folder=self.get_save_folder(size), + save_folder=self.get_save_folder(size, run_duration), rank_microbatch_size=rank_mbz, metrics_collect_interval=10, cancel_check_interval=1, + z_loss_multiplier=1e-5, compile_loss=True, - max_duration=self.get_duration(size), + fused_loss=False, + max_duration=self.get_duration(run_duration), ) .with_callback( - "lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=2000)) + "lr_scheduler", + SchedulerCallback( + scheduler=CosWithWarmup(warmup_steps=round(self.model_size / global_bz)) + ), ) .with_callback("gpu_monitor", GPUMemoryMonitorCallback()) .with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0)) @@ -359,7 +422,7 @@ def get_trainer_config( .with_callback( "comet", CometCallback( - name=f"{self.name}-{size}", + name=f"{self.name}-{size}-{run_duration}", workspace="ai2", project=self.project, enabled=True, @@ -369,7 +432,7 @@ def get_trainer_config( .with_callback( "wandb", WandBCallback( - name=f"{self.name}-{size}", + name=f"{self.name}-{size}-{run_duration}", entity="ai2", project=self.project, enabled=False, @@ -385,7 +448,13 @@ def validate(self): :raises OLMoConfigurationError: If the ladder has any issues. """ for size in ModelSize: - target_size = int(size[:-1]) + # validating to match old ladder sizes. 
+            if size == ModelSize.size_1B:
+                target_size = 1.3
+            elif size == ModelSize.size_3B:
+                target_size = 3.2
+            else:
+                target_size = int(size[:-1])
             if size.endswith("M"):
                 target_size = target_size * 10**6
             elif size.endswith("B"):
@@ -403,9 +472,9 @@
                     f"too far from target size of {size}: {model_config}"
                 )
 
-            self.get_optim_config(size=size)
+            self.get_optim_config()
             self.get_rank_microbatch_size(size=size, gpu_type="H100")
-            bz_tokens = self.get_global_batch_size(size=size)
+            bz_tokens = self.get_global_batch_size()
             if bz_tokens % self.sequence_length != 0:
                 raise OLMoConfigurationError(
                     f"Batch size of {bz_tokens:,d} tokens for model size {size} "
diff --git a/src/olmo_core/nn/functional/cross_entropy_loss.py b/src/olmo_core/nn/functional/cross_entropy_loss.py
index 205a708b7..98f971c1a 100644
--- a/src/olmo_core/nn/functional/cross_entropy_loss.py
+++ b/src/olmo_core/nn/functional/cross_entropy_loss.py
@@ -6,7 +6,7 @@
 __all__ = ["cross_entropy_loss", "fused_cross_entropy_loss"]
 
 
-def cross_entropy_loss(
+def new_cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     *,
@@ -50,6 +50,26 @@
     return loss, z_loss
 
 
+def cross_entropy_loss(
+    logits,
+    labels,
+    ignore_index: int = -100,
+    reduction: str = "mean",
+    compute_z_loss: bool = False,
+    z_loss_multiplier: float = 1e-4,
+):
+    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
+    if not compute_z_loss:
+        return loss, None
+    z_squared = logits.logsumexp(-1).pow(2)
+    if reduction == "mean":
+        z_squared = (z_squared * (labels != ignore_index)).mean()
+    elif reduction == "sum":
+        z_squared = (z_squared * (labels != ignore_index)).sum()
+    z_loss = z_loss_multiplier * z_squared
+    return loss, z_loss
+
+
 _fused_cross_entropy_loss: Optional[Callable] = None
 
 try:
diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py
index 94e864ce4..5f76dbc28 100644
--- a/src/olmo_core/nn/transformer/config.py
+++ b/src/olmo_core/nn/transformer/config.py
@@ -366,7 +366,7 @@ def olmo2_190M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
     def olmo2_370M(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=1024,
-            hidden_size_multiplier=1.4,
+            hidden_size_multiplier=1.5,
             n_layers=kwargs.pop("n_layers", 16),
             n_heads=kwargs.pop("n_heads", 16),
             vocab_size=vocab_size,
@@ -418,6 +418,7 @@ def olmo2_1B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
             qk_norm=kwargs.pop("qk_norm", True),
             rope_theta=kwargs.pop("rope_theta", 500_000),
             layer_norm_eps=1e-6,
+            hidden_size_multiplier=1.5,
             **kwargs,
         )
 
@@ -425,7 +426,7 @@ def olmo2_3B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
         return cls.llama_like(
             d_model=3328,
-            hidden_size_multiplier=1.4,
+            hidden_size_multiplier=1.5,
             n_layers=kwargs.pop("n_layers", 16),
             n_heads=kwargs.pop("n_heads", 16),
             vocab_size=vocab_size,
diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py
index 233d12995..48cb73f66 100644
--- a/src/olmo_core/train/trainer.py
+++ b/src/olmo_core/train/trainer.py
@@ -1172,7 +1172,10 @@ def _train_microbatch_context(
 
     def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
         # Record how many instances are going to be skipped (masked out).
         if (instance_mask := batch.get("instance_mask")) is not None and not dry_run:
-            self.record_metric("train/masked instances", (~instance_mask).sum(), ReduceType.sum)
+            masked = (~instance_mask).float()
+            self.record_metric("train/masked instances (%)", masked.mean(), ReduceType.mean)
+            # TODO: remove this before merging
+            self.record_metric("train/masked instances (rank 0 count)", masked.sum(), None)
 
         # Zero-gradients.
         self.optim.zero_grad(set_to_none=True)
diff --git a/src/scripts/train/OLMo2-ladder.py b/src/scripts/train/OLMo2-ladder.py
index b9b465f6b..9b21a73ce 100644
--- a/src/scripts/train/OLMo2-ladder.py
+++ b/src/scripts/train/OLMo2-ladder.py
@@ -35,21 +35,26 @@ class BaselineModelLadder(ModelLadder):
         ModelSize.size_1B: dict(n_layers=16),  # need to scale down our actual 1B model
     }
 
-    def get_model_config(self, *, size: ModelSize) -> TransformerConfig:
+    def _get_model_config(self, *, size: ModelSize) -> TransformerConfig:
+        # if size in [ModelSize.size_7B, ModelSize.size_13B]:
+        #     data_parallel_type = DataParallelType.fsdp
+        # else:
+        #     data_parallel_type = DataParallelType.ddp
+        data_parallel_type = DataParallelType.hsdp
         return getattr(TransformerConfig, f"olmo2_{size}")(
             vocab_size=self.tokenizer.padded_vocab_size(),
             init_seed=self.init_seed,
             compile=True,
             dp_config=TransformerDataParallelConfig(
-                name=DataParallelType.hsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
+                name=data_parallel_type, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
             ),
             **self.MODEL_OVERRIDES.get(size, {}),
         )
 
-    def get_optim_config(self, *, size: ModelSize) -> OptimConfig:
+    def get_optim_config(self) -> OptimConfig:
         # Calculate LR according to https://api.semanticscholar.org/CorpusID:270764838
         assert self.sequence_length in {2048, 4096}
-        lr = 0.0047 * (size.num_params / 108000000) ** (-1 / 3)
+        lr = 0.0047 * (self.model_size / 108000000) ** (-1 / 3)
 
         if self.sequence_length == 4096:
             lr /= 4
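
A quick, self-contained sketch (not part of the patch) of what the new instance filter does, using only the `InstanceFilterConfig` defaults and the helpers added to `olmo_core/data/utils.py` above. An instance whose tokens contain a short pattern repeated at least `repetition_max_count` times would get `instance_mask=False` from `_validate_instance()`, and `get_labels()` then drops it from the loss:

```python
import numpy as np

from olmo_core.data import InstanceFilterConfig
from olmo_core.data.utils import find_periodic_sequences

cfg = InstanceFilterConfig()  # defaults: max_period=13, min_period=1, max_count=32

# Toy "instance": the 3-token pattern [7, 8, 9] repeated 40 times, followed by normal tokens.
input_ids = np.array([7, 8, 9] * 40 + [1, 2, 3, 4, 5])

instance_mask = True
for match in find_periodic_sequences(
    input_ids,
    max_period=cfg.repetition_max_period,
    min_period=cfg.repetition_min_period,
):
    # `match` is a RepetitionTuple(start, end, period, times).
    if match.times >= cfg.repetition_max_count:
        instance_mask = False  # mirrors NumpyFSLDataset._validate_instance()
        break

print(instance_mask)  # False: the pattern repeats 40 times, exceeding the max count of 32
```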
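
For the ladder changes, here is a small worked example (also not part of the patch; the parameter count is an assumption) of how a `RunDuration` turns into a token budget and how the peak LR is now derived from the measured non-embedding parameter count (`model_size`) rather than the nominal `ModelSize`:

```python
# Assumed model: ~1.3e9 non-embedding parameters, roughly what the 1B rung targets in validate().
model_size = 1_300_000_000
sequence_length = 2048

# RunDuration.multiplier just parses the prefix of the enum value, e.g. "5xC" -> 5.0.
multiplier = float("5xC".split("xC")[0])

# Token budget: multiplier * 20 tokens per parameter (Chinchilla-style), as in ModelLadder.get_duration().
max_tokens = int(multiplier * 20) * model_size  # 130B tokens for this model at 5xC

# Peak LR, as in BaselineModelLadder.get_optim_config(); divided by 4 for 4096-token sequences.
lr = 0.0047 * (model_size / 108_000_000) ** (-1 / 3)
if sequence_length == 4096:
    lr /= 4

print(f"{max_tokens:,} tokens, peak lr = {lr:.2e}")
```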