Mapping new ladder to old ladder #146

Open · wants to merge 29 commits into base: main

Commits (29):

95c0c55  make duration multiplier configurable (AkshitaB, Jan 24, 2025)
28d5e21  update changelog (AkshitaB, Jan 24, 2025)
49db66d  add to __all__ (AkshitaB, Jan 24, 2025)
ee2be4c  fix command (AkshitaB, Jan 24, 2025)
fcf102a  change data parallel type (AkshitaB, Jan 24, 2025)
70dc6da  hsdp (AkshitaB, Jan 28, 2025)
56fc563  add duration to name (AkshitaB, Jan 28, 2025)
285a5b9  fix bug in overriding (AkshitaB, Jan 28, 2025)
232f217  use actual num params (AkshitaB, Jan 28, 2025)
55c9abf  Merge branch 'main' into akshitab/ladder_xC (AkshitaB, Jan 28, 2025)
5ca1e7f  fix (AkshitaB, Jan 28, 2025)
e15448c  remove extra files (AkshitaB, Jan 28, 2025)
65fab16  add zloss (AkshitaB, Jan 29, 2025)
5ae6342  fix mock batch (AkshitaB, Jan 29, 2025)
3fa28a8  loss settings: fused=True, compile=False (AkshitaB, Jan 29, 2025)
de38c25  Merge branch 'main' into akshitab/ladder_xC (AkshitaB, Jan 29, 2025)
829f6fc  not fused (AkshitaB, Jan 29, 2025)
faf0de5  reduce microbatch size (AkshitaB, Jan 29, 2025)
896fa54  reduce mbz further (AkshitaB, Jan 29, 2025)
a10c5e2  reset mbz (AkshitaB, Jan 29, 2025)
4785aaf  fix model params (AkshitaB, Feb 5, 2025)
8d9f535  Port over instance filtering from OLMo codebase (epwalsh, Feb 6, 2025)
2a34982  changelog (epwalsh, Feb 6, 2025)
6650a52  record percentage masked (epwalsh, Feb 6, 2025)
77b192b  include count from rank 0 for comparison (epwalsh, Feb 6, 2025)
269a95f  add to configs (epwalsh, Feb 6, 2025)
bfa53da  Merge branch 'epwalsh/instance-filter' into akshitab/ladder_xC (AkshitaB, Feb 6, 2025)
b55f599  add instance filtering (AkshitaB, Feb 6, 2025)
a617ae2  use loss computation from old trainer, for debugging (AkshitaB, Feb 7, 2025)

5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `instance_filter_config` field to `NumpyDatasetConfig`.

## [v1.8.0](https://github.com/allenai/OLMo-core/releases/tag/v1.8.0) - 2025-01-29

### Added
@@ -24,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for logging rich Table objects as text in source mixture datasets.
- Added `unshard_strategy` parameter to `unshard_checkpoint()` function in `olmo_core.distributed.checkpoint`.
- Added function `load_keys()` to `olmo_core.distributed.checkpoint`.
- Added `RunDuration` in `model_ladder` to configure training durations in terms of Chinchilla multipliers.

### Changed

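For context on the `RunDuration` entry: a Chinchilla multiplier expresses a run's token budget as a multiple of the compute-optimal budget, commonly approximated as 20 tokens per parameter. A minimal sketch of the arithmetic (the 20:1 ratio and the helper below are illustrative assumptions, not the `model_ladder` API):

```python
# Hypothetical sketch of the Chinchilla-multiplier arithmetic; assumes the
# common ~20-tokens-per-parameter compute-optimal ratio. Not the model_ladder API.
def run_duration_tokens(num_params: int, multiplier: float, tokens_per_param: int = 20) -> int:
    """Token budget for a run expressed as a Chinchilla multiplier (1xC, 2xC, ...)."""
    return int(multiplier * tokens_per_param * num_params)

# 2xC for a 1B-parameter model -> 40,000,000,000 tokens
print(f"{run_duration_tokens(1_000_000_000, 2.0):,}")
```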
2 changes: 2 additions & 0 deletions src/olmo_core/data/__init__.py
@@ -22,6 +22,7 @@
)
from .mixes import DataMix, DataMixBase
from .numpy_dataset import (
InstanceFilterConfig,
NumpyDatasetBase,
NumpyDatasetConfig,
NumpyFSLDataset,
@@ -50,6 +51,7 @@
"VSLGrowLinearCurriculum",
"NumpyDatasetConfig",
"NumpyDatasetType",
"InstanceFilterConfig",
"VSLCurriculumType",
"VSLCurriculumConfig",
"NumpyDatasetDType",
39 changes: 39 additions & 0 deletions src/olmo_core/data/numpy_dataset.py
@@ -43,6 +43,7 @@
bucket_documents,
chunk_array,
divide_into_buckets,
find_periodic_sequences,
get_doc_lengths_from_indices,
get_document_lengths,
get_rng,
@@ -305,6 +306,25 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
"""
raise NotImplementedError

def _validate_instance(
self, input_ids: torch.Tensor, instance_filter_config: InstanceFilterConfig
) -> bool:
for m in find_periodic_sequences(
input_ids.numpy(),
max_period=instance_filter_config.repetition_max_period,
min_period=instance_filter_config.repetition_min_period,
):
if m.times >= instance_filter_config.repetition_max_count:
return False
return True


@dataclass
class InstanceFilterConfig(Config):
repetition_max_period: int = 13
repetition_min_period: int = 1
repetition_max_count: int = 32


class NumpyFSLDataset(NumpyDatasetBase, Dataset[Dict[str, Any]]):
"""
@@ -352,6 +372,7 @@ def __init__(
include_instance_metadata: Optional[bool] = None,
generate_doc_lengths: bool = False,
max_target_sequence_length: Optional[int] = None,
instance_filter_config: Optional[InstanceFilterConfig] = None,
):
if max_target_sequence_length is not None and (
max_target_sequence_length < sequence_length
@@ -386,6 +407,7 @@ def __init__(
self._num_instances: Optional[int] = None
self._include_instance_metadata = include_instance_metadata
self._generate_doc_lengths = generate_doc_lengths
self.instance_filter_config = instance_filter_config

@property
def num_tokens(self) -> int:
@@ -450,6 +472,9 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
input_ids = self._read_chunk_from_array(self.paths[array_index], array_local_index)
out: Dict[str, Any] = {"input_ids": input_ids}

if self.instance_filter_config is not None:
out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)

if self._include_instance_metadata:
metadata = self._metadata[array_index]
out["metadata"] = deepcopy(metadata)
@@ -525,6 +550,7 @@ def __init__(
include_instance_metadata: Optional[bool] = None,
generate_doc_lengths: bool = False,
max_target_sequence_length: Optional[int] = None,
instance_filter_config: Optional[InstanceFilterConfig] = None,
):
if max_target_sequence_length is not None and (
max_target_sequence_length < sequence_length
@@ -565,6 +591,7 @@ def __init__(
self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None
self._path_offset_index = path_offset_index
self._seed = seed
self.instance_filter_config = instance_filter_config

@property
def indices_dtype(
@@ -692,6 +719,7 @@ def __init__(
dtype: NumpyUIntTypes = np.uint16,
metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
include_instance_metadata: Optional[bool] = None,
instance_filter_config: Optional[InstanceFilterConfig] = None,
):
super().__init__(
*paths,
@@ -702,6 +730,7 @@
dtype=dtype,
metadata=metadata,
include_instance_metadata=include_instance_metadata,
instance_filter_config=instance_filter_config,
)
self._array_instance_offsets: Optional[Tuple[Tuple[int, int], ...]] = None

@@ -1122,6 +1151,7 @@ def __init__(
dtype: NumpyUIntTypes = np.uint16,
metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
include_instance_metadata: Optional[bool] = None,
instance_filter_config: Optional[InstanceFilterConfig] = None,
):
if math.log(max_sequence_length, 2) % 1 != 0:
raise OLMoConfigurationError("'max_sequence_length' must be a power of 2")
@@ -1161,6 +1191,7 @@ def __init__(
self._array_offsets: Optional[Tuple[Tuple[int, int], ...]] = None
self._lengths_dtype: Optional[NumpyUIntTypes] = None
self._instances_per_bucket: Optional[Tuple[Tuple[int, int], ...]] = None
self.instance_filter_config = instance_filter_config

@property
def fingerprint_version(self) -> str:
@@ -1251,6 +1282,9 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
input_ids = self._read_chunk_from_array(self.paths[array_index], array_local_index)
out: Dict[str, Any] = {"input_ids": input_ids}

if self.instance_filter_config is not None:
out["instance_mask"] = self._validate_instance(input_ids, self.instance_filter_config)

if self._include_instance_metadata:
metadata = self._metadata[array_index]
out["metadata"] = deepcopy(metadata)
@@ -1570,6 +1604,7 @@ class NumpyDatasetConfig(Config):
You can save a lot of time and disk space by setting this to a common directory across
all of your runs.
"""
instance_filter_config: Optional[InstanceFilterConfig] = None

def validate(self):
if self.name in (NumpyDatasetType.fsl, NumpyDatasetType.padded_fsl):
@@ -1724,6 +1759,7 @@ def build(self) -> NumpyDatasetBase:
include_instance_metadata=self.include_instance_metadata,
generate_doc_lengths=self.generate_doc_lengths,
path_offset_index=mixture.to_index(),
instance_filter_config=self.instance_filter_config,
)
else:
dataset = NumpyFSLDataset(
@@ -1737,6 +1773,7 @@
metadata=metadata,
include_instance_metadata=self.include_instance_metadata,
generate_doc_lengths=self.generate_doc_lengths,
instance_filter_config=self.instance_filter_config,
)
elif self.name == NumpyDatasetType.padded_fsl:
if self.sequence_length is None:
@@ -1776,6 +1813,7 @@ def build(self) -> NumpyDatasetBase:
dtype=self.get_dtype(),
metadata=metadata,
include_instance_metadata=self.include_instance_metadata,
instance_filter_config=self.instance_filter_config,
)
elif self.name == NumpyDatasetType.vsl:
if self.max_sequence_length is None:
@@ -1801,6 +1839,7 @@
dtype=self.get_dtype(),
metadata=metadata,
include_instance_metadata=self.include_instance_metadata,
instance_filter_config=self.instance_filter_config,
)
else:
raise NotImplementedError(self.name)
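Taken together, these `numpy_dataset.py` changes thread the filter from config to output: when `instance_filter_config` is set, `__getitem__` attaches a boolean `instance_mask` that is `False` for instances containing long periodic repetitions. A standalone sketch of the predicate (a re-implementation of `_validate_instance` for illustration; the toy tensors are made up):

```python
import torch
from olmo_core.data import InstanceFilterConfig
from olmo_core.data.utils import find_periodic_sequences

def validate_instance(input_ids: torch.Tensor, cfg: InstanceFilterConfig) -> bool:
    """Mirror of the PR's _validate_instance: returns False when the instance
    contains a periodic sequence repeating at least repetition_max_count times."""
    for m in find_periodic_sequences(
        input_ids.numpy(),
        max_period=cfg.repetition_max_period,
        min_period=cfg.repetition_min_period,
    ):
        if m.times >= cfg.repetition_max_count:
            return False
    return True

cfg = InstanceFilterConfig()      # defaults: periods 1-13, 32+ repeats
bad = torch.tensor([7, 8] * 40)   # period-2 pattern repeated 40 times
good = torch.arange(80)           # no repetition
print(validate_instance(bad, cfg), validate_instance(good, cfg))  # False True
```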
116 changes: 116 additions & 0 deletions src/olmo_core/data/utils.py
@@ -10,6 +10,7 @@
Generator,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
@@ -468,3 +469,118 @@ def get_labels(batch: Dict[str, Any], label_ignore_index: int = -100) -> torch.Tensor:
if instance_mask is not None:
labels.masked_fill_(~instance_mask.unsqueeze(-1), value=label_ignore_index)
return labels[..., 1:].contiguous()


def find_end_first_consecutive_true(arr: np.ndarray) -> int:
"""Function to find the end position of the first consecutive sequence of True in an array."""
if not arr[0]:
return 0

prog = np.cumsum(arr)
if prog[-1] == len(arr):
return len(arr)

true_locs = np.where(prog[:-1:] == prog[1::])[0]

return true_locs[0] + 1


def find_start_last_consecutive_true(arr: np.ndarray) -> int:
"""Function to find the start position of the last consecutive sequence of True in an array."""
reverse = find_end_first_consecutive_true(arr[::-1])
return len(arr) - reverse if reverse > 0 else -1


def group_consecutive_values(arr: np.ndarray, stepsize: int = 1) -> List[np.ndarray]:
"""Function to group consecutive values in an array."""
return np.split(arr, np.where(np.diff(arr) != stepsize)[0] + 1)


class RepetitionTuple(NamedTuple):
"""Tuple to store information about a periodic sequence."""

start: int
end: int
period: int
times: int


def find_periodic_sequences(
arr: np.ndarray, max_period: int, min_period: int = 1, mask_value: int = -1
) -> Generator[RepetitionTuple, None, None]:
"""Function to find periodic sequences in an array.

This function sweeps through the array and checks for sequences of length
[min_period, max_period] that repeat at least 3 times. To do so, it
reshapes the array into a matrix with `period` columns and checks if each
row is equal to the previous row. Blocks of repeating rows indicate repeating
sequences.

Because there's no guarantee that the sequences start at the beginning of each
row, it can only detect sequences that repeat at least 3 times. To account
for the fact that sequences may not start at the beginning of each row (or
end at the end of each row), we check the end of the previous row and the
start of the next row to determine the actual start and end positions of the
sequence.

Args:
arr (np.ndarray): The array to search for periodic sequences.
max_period (int): The maximum period to check for.
min_period (int, optional): The minimum period to check for. Defaults to 1.
mask_value (int, optional): The value to use to pad the array. Defaults to -1.
"""
# make sure the mask_value is not in the array
if (arr == mask_value).sum() > 0:
raise ValueError("`mask_value` is in the array")

# since we can only detect sequences that repeat at least 3 times,
# there is no point in checking for periods greater than 1/3 of the length
max_period = min(max_period, len(arr) // 3)

for period in range(min_period, max_period + 1):
# pad the array so that it can be reshaped into a matrix matching the period
padded_arr = np.pad(arr, (0, period - (len(arr) % period)), constant_values=mask_value)
shaped_arr = padded_arr.reshape(-1, period)

# find rows that are equal to the previous row; these are the possibly-periodic sequences
is_equal_to_prev_row = shaped_arr == np.roll(shaped_arr, shift=1, axis=0)
rows_with_period, *_ = np.where(is_equal_to_prev_row.all(axis=1))

# no sequences found with this period
if len(rows_with_period) == 0:
continue

# this finds the start and end positions of the sequences with period `period`
where_true_consecutive = group_consecutive_values(rows_with_period)

for sequence in where_true_consecutive:
start_row = sequence[0]
end_row = sequence[-1]

# we check if any value at the end of the previous row is True, e.g.:
# [[False, False, True, True]
# [True, True, True, True]]
# (in the case above, start offset is 2). If so, we subtract that from the
# period to get the actual start offset.
start_offset = find_start_last_consecutive_true(is_equal_to_prev_row[start_row - 1])
start_offset = period - start_offset if start_offset > 0 else 0

# same idea as above, we want to compute offset. Only difference is that
# `find_end_first_consecutive_true` already returns the offset, so we don't
# need to subtract from the period.
end_offset = find_end_first_consecutive_true(is_equal_to_prev_row[end_row + 1])

# because we are always comparing with preceding row in
# `is_equal_to_prev_row`, we need to subtract 1 from the row number
start_pos = (start_row - 1) * period - start_offset

# note that the end position is exclusive
end_pos = ((end_row + 1) * period) + end_offset

out = RepetitionTuple(
start=start_pos, end=end_pos, period=period, times=(end_pos - start_pos) // period
)
# this algorithm cannot accurately determine the period of a sequence
# that repeats less than 3 times, so only yield longer repetitions
if out.times > 2:
yield out
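A quick hand-check of the detector (the array and the expected tuple are my own worked example, not from the PR):

```python
import numpy as np
from olmo_core.data.utils import find_periodic_sequences

# "1 2 3" repeats 4 times with period 3, starting at index 3 and ending
# (exclusively) at index 15; shorter and longer periods match nothing here.
arr = np.array([5, 7, 9, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 8])
for rep in find_periodic_sequences(arr, max_period=13):
    print(rep)  # RepetitionTuple(start=3, end=15, period=3, times=4)
```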
6 changes: 6 additions & 0 deletions src/olmo_core/internal/experiment.py
@@ -9,6 +9,7 @@
from olmo_core.config import Config, StrEnum
from olmo_core.data import (
DataMix,
InstanceFilterConfig,
NumpyDataLoaderConfig,
NumpyDatasetConfig,
NumpyDatasetType,
@@ -170,6 +171,11 @@ def build_common_components(
name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False
),
work_dir=get_work_dir(root_dir),
instance_filter_config=InstanceFilterConfig(
repetition_max_period=13,
repetition_min_period=1,
repetition_max_count=32,
),
)

data_loader_config = NumpyDataLoaderConfig(
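With these defaults, any instance containing a sequence of period 1 to 13 that repeats 32 or more times gets `instance_mask=False`, and `get_labels` (see the `utils.py` hunk above) then fills its labels with the ignore index so it contributes nothing to the loss. A small sketch of that interaction (the batch is illustrative, and it assumes the unshown start of `get_labels` simply shifts `input_ids`):

```python
import torch
from olmo_core.data.utils import get_labels

# Toy batch: the second instance was flagged by the filter (mask False), so
# every one of its labels should become label_ignore_index (-100).
batch = {
    "input_ids": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]),
    "instance_mask": torch.tensor([True, False]),
}
print(get_labels(batch))
# expected: tensor([[   2,    3,    4],
#                   [-100, -100, -100]])
```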