Skip to content

Commit 1c12fd8

Browse files
modelopt nas search() implementation for the compress algorithm (#490)
Signed-off-by: Daniel Korzekwa <[email protected]> Signed-off-by: Keval Morabia <[email protected]> Co-authored-by: Keval Morabia <[email protected]>
1 parent 002b8b5 commit 1c12fd8

File tree

5 files changed

+239
-73
lines changed

5 files changed

+239
-73
lines changed

modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,17 @@
2020
import datetime
2121
from pathlib import Path
2222

23+
import build_library_and_stats
24+
import mip_and_realize_models
2325
import pruning_ckpts
2426
import score_pruning_activations
27+
import scoring
2528
import torch
26-
from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm
2729
from torch import nn
2830

31+
from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
32+
convert_llama3_to_decilm,
33+
)
2934
from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir
3035
from modelopt.torch._compress.runtime import NativeDdpRuntime
3136
from modelopt.torch.nas.conversion import NASModeRegistry
@@ -37,7 +42,7 @@
3742
ModeDescriptor,
3843
RestoreEntrypoint,
3944
)
40-
from modelopt.torch.opt.searcher import BaseSearcher
45+
from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict
4146

4247

4348
class CompressModel(nn.Module):
@@ -90,10 +95,19 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR
9095
9196
The output of this step will be used by mtn.search() to perform the NAS search.
9297
"""
98+
99+
# NativeDdpRuntime must be initialized/closed from outside of this function, so we are
100+
# NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it.
93101
runtime = NativeDdpRuntime(
94102
dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
95103
)
96104

105+
# Required for mtn.search() to read NAS configuration
106+
model.hydra_config_dir = config.hydra_config_dir
107+
model.hydra_config_name = config.hydra_config_name
108+
model.puzzle_dir = config.puzzle_dir
109+
model.dataset_path = config.dataset_path
110+
97111
# Load hydra config
98112
hydra_cfg = initialize_hydra_config_for_dir(
99113
config_dir=config.hydra_config_dir,
@@ -146,7 +160,8 @@ def config_class(self) -> type[ModeloptBaseConfig]:
146160
@property
147161
def search_algorithm(self) -> type[BaseSearcher]:
148162
"""Return the associated searcher implementation."""
149-
raise NotImplementedError("Compress mode does not have a search algorithm yet.")
163+
164+
return CompressSearcher
150165

151166
@property
152167
def convert(self) -> ConvertEntrypoint:
@@ -165,3 +180,40 @@ def export_mode(self) -> str | None:
165180
for the compress algorithm.
166181
"""
167182
return "export_nas"
183+
184+
185+
class CompressSearcher(BaseSearcher):
186+
"""Runs NAS search for the Compress mode."""
187+
188+
@property
189+
def default_state_dict(self) -> SearchStateDict:
190+
"""Not needed for the compress mode as we are not saving any model state"""
191+
return {}
192+
193+
def run_search(self) -> None:
194+
# NativeDdpRuntime must be initialized/closed from outside of this function, so we are
195+
# NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it.
196+
runtime = NativeDdpRuntime(
197+
dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
198+
)
199+
200+
# Load hydra config
201+
hydra_cfg = initialize_hydra_config_for_dir(
202+
config_dir=self.model.hydra_config_dir,
203+
config_name=self.model.hydra_config_name,
204+
overrides=[
205+
f"puzzle_dir={self.model.puzzle_dir}",
206+
f"dataset_path={self.model.dataset_path}",
207+
],
208+
)
209+
210+
# Build_library_and_stats (single process)
211+
if runtime.global_rank == 0:
212+
build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
213+
runtime.wait_for_everyone()
214+
215+
# Calc_one_block_scores (distributed processing)
216+
scoring.launch_scoring(hydra_cfg, runtime)
217+
218+
# mip_and_realize_models (distributed processing)
219+
mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime)

tests/experimental/torch/_compress/compress_test_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,68 @@
1919

2020
import torch
2121
from datasets import Dataset, DatasetDict
22+
from puzzle_tools.hydra_utils import register_hydra_resolvers
2223
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase
2324

2425

26+
def setup_test_model_and_data(
27+
project_root_path: Path,
28+
tmp_path: Path,
29+
rank: int,
30+
runtime,
31+
) -> tuple[
32+
Path,
33+
Path,
34+
Path,
35+
Path,
36+
str,
37+
]:
38+
"""
39+
Set up the test model and data for the compress NAS search.
40+
41+
Args:
42+
project_root_path (Path): the root path of the project
43+
tmp_path (Path): the temporary path to use for the test
44+
rank (int): the rank of the process
45+
runtime: the runtime to use for the test
46+
47+
Returns:
48+
tuple[Path, Path, Path, Path, str]:
49+
the puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name
50+
"""
51+
52+
# Register Hydra custom resolvers (needed for config resolution)
53+
register_hydra_resolvers()
54+
55+
# The inputs for the nas.convert() step.
56+
#
57+
puzzle_dir = tmp_path
58+
llama_checkpoint_path = puzzle_dir / "input_model/llama"
59+
dataset_path = puzzle_dir / "dummy_dataset"
60+
hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"
61+
hydra_config_name = "Llama-3_1-8B"
62+
63+
if rank == 0:
64+
# Setup puzzle_dir and dataset
65+
setup_puzzle_dir(puzzle_dir)
66+
save_dummy_dataset(dataset_path)
67+
68+
# Create a small Llama model
69+
tokenizer = create_tokenizer(project_root_path)
70+
create_and_save_small_llama_model(
71+
llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
72+
)
73+
runtime.wait_for_everyone()
74+
75+
return (
76+
puzzle_dir,
77+
llama_checkpoint_path,
78+
dataset_path,
79+
hydra_config_dir,
80+
hydra_config_name,
81+
)
82+
83+
2584
def create_and_save_small_llama_model(
2685
output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase
2786
):

tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,7 @@
2020

2121
import torch
2222
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
23-
from experimental.torch._compress.compress_test_utils import (
24-
create_and_save_small_llama_model,
25-
create_tokenizer,
26-
save_dummy_dataset,
27-
setup_puzzle_dir,
28-
)
29-
from puzzle_tools.hydra_utils import register_hydra_resolvers
23+
from experimental.torch._compress.compress_test_utils import setup_test_model_and_data
3024

3125
import modelopt.torch.nas as mtn
3226
from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel
@@ -48,32 +42,13 @@ def test_nas_convert(project_root_path: Path, tmp_path: Path):
4842
def _test_nas_convert_multiprocess_job(
4943
project_root_path: Path, tmp_path: Path, rank: int, size: int
5044
):
51-
# Register Hydra custom resolvers (needed for config resolution)
52-
register_hydra_resolvers()
53-
54-
#
55-
# The inputs for the nas.convert() step.
56-
#
57-
puzzle_dir = tmp_path
58-
llama_checkpoint_path = puzzle_dir / "ckpts/llama"
59-
dataset_path = puzzle_dir / "dummy_dataset"
60-
hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"
61-
hydra_config_name = "Llama-3_1-8B"
62-
6345
with NativeDdpRuntime(
6446
dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
6547
) as runtime:
66-
if rank == 0:
67-
# Setup puzzle_dir and dataset
68-
setup_puzzle_dir(puzzle_dir)
69-
save_dummy_dataset(dataset_path)
70-
71-
# Create a small Llama model
72-
tokenizer = create_tokenizer(project_root_path)
73-
create_and_save_small_llama_model(
74-
llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
75-
)
76-
runtime.wait_for_everyone()
48+
# Setup the test model and data.
49+
puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = (
50+
setup_test_model_and_data(project_root_path, tmp_path, rank, runtime)
51+
)
7752

7853
#
7954
# Run the mtn.convert() step
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
#
17+
# See tests/experimental/torch/_compress/test_compress.py for instructions on how to run this test
18+
# TODO: Remove those instructions once this test runs automatically on CI
19+
#
20+
import datetime
21+
from functools import partial
22+
from pathlib import Path
23+
24+
import torch
25+
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
26+
from experimental.torch._compress.compress_test_utils import setup_test_model_and_data
27+
28+
import modelopt.torch.nas as mtn
29+
from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel
30+
from modelopt.torch._compress.runtime import NativeDdpRuntime
31+
32+
33+
def test_nas_search(project_root_path: Path, tmp_path: Path):
34+
spawn_multiprocess_job(
35+
size=torch.cuda.device_count(),
36+
job=partial(_test_nas_search_multiprocess_job, project_root_path, tmp_path),
37+
backend="nccl",
38+
)
39+
40+
41+
def _test_nas_search_multiprocess_job(
42+
project_root_path: Path, tmp_path: Path, rank: int, size: int
43+
):
44+
with NativeDdpRuntime(
45+
dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
46+
) as runtime:
47+
# Setup the test model and data.
48+
puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = (
49+
setup_test_model_and_data(project_root_path, tmp_path, rank, runtime)
50+
)
51+
52+
#
53+
# Run the mtn.convert() step
54+
#
55+
input_model = CompressModel()
56+
converted_model = mtn.convert(
57+
input_model,
58+
mode=[
59+
(
60+
"compress",
61+
{
62+
"puzzle_dir": str(puzzle_dir),
63+
"input_model_path": str(llama_checkpoint_path),
64+
"hydra_config_dir": str(hydra_config_dir),
65+
"hydra_config_name": hydra_config_name,
66+
"dataset_path": str(dataset_path),
67+
},
68+
)
69+
],
70+
)
71+
72+
#
73+
# Run the mtn.search() step
74+
#
75+
mtn.search(
76+
converted_model,
77+
constraints={}, # this is not used as the search space is defined in the hydra config
78+
dummy_input=None, # Not used
79+
config={}, # this is not used as the search space is defined in the hydra config
80+
)
81+
82+
#
83+
# Check assertions for mtn.search() step
84+
#
85+
if rank == 0:
86+
# assertions for the build_library_and_stats step
87+
assert (puzzle_dir / "replacement_library.json").is_file()
88+
assert (puzzle_dir / "subblock_stats.json").is_file()
89+
90+
# assertions for the scoring step
91+
solution_0_filepath = (
92+
puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
93+
)
94+
95+
assert solution_0_filepath.exists()
96+
97+
# assertions for the mip_and_realize_models step
98+
solution_0_ckpt_config_path = (
99+
puzzle_dir
100+
/ "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
101+
)
102+
103+
assert solution_0_ckpt_config_path.exists()
104+
assert (
105+
puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json"
106+
).exists()
107+
108+
runtime.wait_for_everyone()
109+
110+
print("PYTEST SUMMARY: test_nas_search() test has finished successfully")

0 commit comments

Comments
 (0)