
Mapping new ladder to old ladder #146

Open

wants to merge 29 commits into main from akshitab/ladder_xC

Showing changes from 2 of 29 commits.

Commits
95c0c55
make duration multiplier configurable
AkshitaB Jan 24, 2025
28d5e21
update changelog
AkshitaB Jan 24, 2025
49db66d
add to __all__
AkshitaB Jan 24, 2025
ee2be4c
fix command
AkshitaB Jan 24, 2025
fcf102a
change data parallel type
AkshitaB Jan 24, 2025
70dc6da
hsdp
AkshitaB Jan 28, 2025
56fc563
add duration to name
AkshitaB Jan 28, 2025
285a5b9
fix bug in overriding
AkshitaB Jan 28, 2025
232f217
use actual num params
AkshitaB Jan 28, 2025
55c9abf
Merge branch 'main' into akshitab/ladder_xC
AkshitaB Jan 28, 2025
5ca1e7f
fix
AkshitaB Jan 28, 2025
e15448c
remove extra files
AkshitaB Jan 28, 2025
65fab16
add zloss
AkshitaB Jan 29, 2025
5ae6342
fix mock batch
AkshitaB Jan 29, 2025
3fa28a8
loss settings: fused=True, compile=False
AkshitaB Jan 29, 2025
de38c25
Merge branch 'main' into akshitab/ladder_xC
AkshitaB Jan 29, 2025
829f6fc
not fused
AkshitaB Jan 29, 2025
faf0de5
reduce microbatch size
AkshitaB Jan 29, 2025
896fa54
reduce mbz further
AkshitaB Jan 29, 2025
a10c5e2
reset mbz
AkshitaB Jan 29, 2025
4785aaf
fix model params
AkshitaB Feb 5, 2025
8d9f535
Port over instance filtering from OLMo codebase
epwalsh Feb 6, 2025
2a34982
changelog
epwalsh Feb 6, 2025
6650a52
record percentage masked
epwalsh Feb 6, 2025
77b192b
include count from rank 0 for comparison
epwalsh Feb 6, 2025
269a95f
add to configs
epwalsh Feb 6, 2025
bfa53da
Merge branch 'epwalsh/instance-filter' into akshitab/ladder_xC
AkshitaB Feb 6, 2025
b55f599
add instance filtering
AkshitaB Feb 6, 2025
a617ae2
use loss computation from old trainer, for debugging
AkshitaB Feb 7, 2025
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.
- Added the option to throttle checkpoint uploads to one rank from each node at a time.
- Added `RunDuration` in `model_ladder` to configure training durations in terms of Chinchilla multipliers.

### Changed

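For context on the `RunDuration` entry above: the ladder script (diffed below) now takes the duration as an extra positional argument between the model size and the cluster. A hypothetical invocation, where the script path is a placeholder and the cluster is taken from the usage example further down:

$ python ladder.py launch 1B 2xC ai2/pluto-cirrascale --launch.num_nodes=2

Here `2xC` requests twice the Chinchilla-optimal token budget for the 1B model.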
src/olmo_core/internal/model_ladder.py (20 changes: 12 additions & 8 deletions)
@@ -9,7 +9,7 @@
from olmo_core.data import NumpyDataLoaderConfig, NumpyDatasetConfig
from olmo_core.distributed.utils import get_local_rank
from olmo_core.launch.beaker import BeakerLaunchConfig
from olmo_core.model_ladder import ModelLadder, ModelSize
from olmo_core.model_ladder import ModelLadder, ModelSize, RunDuration
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import OptimConfig
from olmo_core.train import (
@@ -97,6 +97,7 @@ def build_config(
ladder: ModelLadder,
script: str,
size: ModelSize,
run_duration: RunDuration,
cmd: SubCmd,
cluster: str,
overrides: List[str],
@@ -118,7 +119,9 @@
optim = ladder.get_optim_config(size=size)
dataset = ladder.get_dataset_config()
data_loader = ladder.get_data_loader_config(size=size)
trainer = ladder.get_trainer_config(size=size, gpu_type=gpu_type, dp_world_size=dp_world_size)
trainer = ladder.get_trainer_config(
size=size, run_duration=run_duration, gpu_type=gpu_type, dp_world_size=dp_world_size
)

return LadderRunConfig(
launch=launch,
@@ -133,7 +136,7 @@

def main(ladder_builder: Callable[[str], ModelLadder]):
usage = f"""
[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]SIZE CLUSTER[/] [i][OVERRIDES...][/]
[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]{'|'.join(SubCmd)}[/] [i b]SIZE RUN_DURATION CLUSTER[/] [i][OVERRIDES...][/]

[b]Subcommands[/]
[b magenta]launch:[/] Launch the script on Beaker with the [b magenta]train[/] subcommand.
@@ -142,16 +145,17 @@ def main(ladder_builder: Callable[[str], ModelLadder]):
[b magenta]dry_run:[/] Pretty print the config to run and exit.

[b]Examples[/]
$ [i]python {sys.argv[0]} {SubCmd.launch} 1B ai2/pluto-cirrascale --launch.num_nodes=2[/]
$ [i]python {sys.argv[0]} {SubCmd.launch} 1B 1xC ai2/pluto-cirrascale --launch.num_nodes=2[/]
""".strip()

try:
script, cmd, size, cluster, overrides = (
script, cmd, size, run_duration, cluster, overrides = (
Review comment (Member):
Is there a reason to make this a required parameter and not just part of the config, with a default like 2xC?

Reply (Contributor Author):
Well, we usually run them as a cell in the grid of {model_sizes} x {chinchilla multipliers}, so it's convenient.

sys.argv[0],
SubCmd(sys.argv[1]),
ModelSize(sys.argv[2]),
sys.argv[3],
sys.argv[4:],
RunDuration(sys.argv[3]),
sys.argv[4],
sys.argv[5:],
)
except (IndexError, ValueError):
import rich
@@ -166,7 +170,7 @@ def main(ladder_builder: Callable[[str], ModelLadder]):
ladder.merge(overrides, prefix="ladder")

# Build run config.
config = build_config(ladder, script, size, cmd, cluster, overrides)
config = build_config(ladder, script, size, run_duration, cmd, cluster, overrides)
config.ladder.validate()

# Run the cmd.
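To make the new calling convention concrete, here is a minimal, self-contained sketch of the argument parsing in this diff, using a hard-coded argument list. The script name is a placeholder, and the `SubCmd` import path is assumed (it is used but not defined in the hunks shown):

from olmo_core.internal.model_ladder import SubCmd  # assumed import path; SubCmd is used in the diff above
from olmo_core.model_ladder import ModelSize, RunDuration

# Stand-in for sys.argv under the new convention: SIZE RUN_DURATION CLUSTER [OVERRIDES...]
argv = ["ladder.py", "launch", "1B", "2xC", "ai2/pluto-cirrascale", "--launch.num_nodes=2"]

script, cmd, size, run_duration, cluster, overrides = (
    argv[0],
    SubCmd(argv[1]),
    ModelSize(argv[2]),
    RunDuration(argv[3]),  # new positional argument; an unrecognized value raises ValueError,
                           # which the surrounding try/except catches to print the usage string
    argv[4],               # the cluster shifts from argv[3] to argv[4]
    argv[5:],              # config overrides now start at argv[5]
)
assert run_duration is RunDuration.Cx2 and run_duration.multiplier == 2.0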
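On the review question above (making the duration a config default rather than a required CLI argument), a hypothetical sketch of that alternative, not what this PR does: the duration could live on the ladder config with a 2xC default and be changed through the existing override path. The `LadderRunSettings` name below is illustrative only.

from dataclasses import dataclass

from olmo_core.model_ladder import RunDuration

@dataclass
class LadderRunSettings:
    # A config-level default; in the real codebase this field would sit on ModelLadder
    # and be overridable via the `ladder.merge(overrides, prefix="ladder")` call above,
    # e.g. with something like --ladder.run_duration=5xC.
    run_duration: RunDuration = RunDuration.Cx2

print(LadderRunSettings().run_duration.multiplier)  # 2.0 unless overridden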
src/olmo_core/model_ladder.py (42 changes: 39 additions & 3 deletions)
@@ -89,6 +89,39 @@ def num_params(self) -> int:
raise NotImplementedError(self)


class RunDuration(StrEnum):
"""
An enumeration of the standard training durations for the ladder, in terms of Chinchilla multipliers.
"""

Cx0_5 = "0.5xC"
"""
Multiplier of 0.5.
"""

Cx1 = "1xC"
"""
Multiplier of 1.
"""
Cx2 = "2xC"
"""
Multiplier of 2.
"""
Cx5 = "5xC"
"""
Multiplier of 5.
"""

Cx10 = "10xC"
"""
Multiplier of 10.
"""

@property
def multiplier(self) -> float:
return float(self.split("xC")[0])


@beta_feature
@dataclass
class ModelLadder(Config, metaclass=ABCMeta):
@@ -236,18 +269,21 @@ def get_global_batch_size(self, *, size: ModelSize) -> int:

return self.sequence_length * global_batch_size

def get_duration(self, size: ModelSize) -> Duration:
def get_duration(
self, size: ModelSize, run_duration: RunDuration = RunDuration.Cx2
) -> Duration:
"""
Get the duration to train for, given the model size and run duration. Defaults to 2 x Chinchilla optimal.

:param size: The target model size.
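:param run_duration: The duration to train for, in terms of Chinchilla multipliers.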
"""
return Duration.tokens(2 * 20 * size.num_params)
return Duration.tokens(int(run_duration.multiplier * 20) * size.num_params)

def get_trainer_config(
self,
*,
size: ModelSize,
run_duration: RunDuration,
gpu_type: str,
dp_world_size: int,
) -> TrainerConfig:
@@ -315,7 +351,7 @@ def get_trainer_config(
metrics_collect_interval=10,
cancel_check_interval=1,
compile_loss=True,
max_duration=self.get_duration(size),
max_duration=self.get_duration(size, run_duration),
)
.with_callback(
"lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=2000))
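A quick usage note for the new `RunDuration` enum above: members are constructed from their string values, and `multiplier` simply parses the number in front of `xC`. A small sanity check based only on the code in this diff:

from olmo_core.model_ladder import RunDuration

assert RunDuration("0.5xC") is RunDuration.Cx0_5
assert RunDuration("0.5xC").multiplier == 0.5   # float("0.5xC".split("xC")[0])
assert RunDuration("10xC").multiplier == 10.0
assert [d.multiplier for d in RunDuration] == [0.5, 1.0, 2.0, 5.0, 10.0]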
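Finally, a worked example of the updated `get_duration` arithmetic, which feeds `max_duration` in the trainer config above. The 1e9 parameter count is a nominal stand-in for the 1B size; the real value comes from `size.num_params`:

from olmo_core.model_ladder import RunDuration

nominal_num_params = 1_000_000_000  # stand-in for size.num_params of the 1B model

for run_duration in RunDuration:
    # Same formula as get_duration: multiplier * 20 tokens per parameter.
    tokens = int(run_duration.multiplier * 20) * nominal_num_params
    print(f"{run_duration.value}: {tokens:,} tokens")

# 0.5xC: 10,000,000,000 tokens
# 1xC: 20,000,000,000 tokens
# 2xC: 40,000,000,000 tokens   (the previous hard-coded default)
# 5xC: 100,000,000,000 tokens
# 10xC: 200,000,000,000 tokens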