
Commit 25b4aed

Merge branch 'feature/compress' into dkorzekwa/compress_tutorial
Signed-off-by: Daniel Korzekwa <[email protected]>
2 parents: 6e1d910 + 1c12fd8

File tree: 7 files changed, +127 −103 lines

examples/compress/configs/llama-3_1-8B_pruneffn_memory/pruning/ffn_pruning.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@ activation_hooks_kwargs:
   target_layer: "mlp.down_proj"
   layer_input_descriptors_path:

-intermediate_size_list: [256] # teacher_intermediate_size is 14336
+intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
 mlp_init_mode: "PruneByActivationsLog"

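The new example list is evenly spaced (steps of 2816) and every entry is a multiple of 256; the values coincide with taking the fractions 1/5 through 4/5 of the teacher width 14336 and rounding each up to the nearest multiple of 256. A minimal sketch that reproduces the list under that reading; the helper `candidate_ffn_widths` and the rounding rule are assumptions for illustration, not part of this commit:

```python
import math

def candidate_ffn_widths(teacher_size: int, num_candidates: int = 4, multiple: int = 256) -> list[int]:
    """Hypothetical helper: evenly spaced pruned FFN widths, rounded up to a multiple of 256."""
    step = teacher_size / (num_candidates + 1)
    return [math.ceil(step * (i + 1) / multiple) * multiple for i in range(num_candidates)]

# Reproduces the values hard-coded in this config.
assert candidate_ffn_widths(14336) == [3072, 5888, 8704, 11520]
```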
modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py

Lines changed: 8 additions & 1 deletion
@@ -96,6 +96,9 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR

     The output of this step will be used by mnt.search() to perform the NAS search.
     """
+
+    # NativeDdpRuntime must be initialized/closed from outside of this function, so we are
+    # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it.
     runtime = NativeDdpRuntime(
         dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
     )

@@ -199,6 +202,8 @@ def default_state_dict(self) -> SearchStateDict:
         return {}

     def run_search(self) -> None:
+        # NativeDdpRuntime must be initialized/closed from outside of this function, so we are
+        # NOT calling runtime.cleanup() here. TODO: Not optimal - redesign it.
         runtime = NativeDdpRuntime(
             dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
         )

@@ -220,10 +225,12 @@ def run_search(self) -> None:
                 "Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)"
             )
         )
-        build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
+
+        build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
         runtime.wait_for_everyone()

         # Calc_one_block_scores (distributed processing)
+
         print(timestamped("Compress Progress 6/8: calculating one block scores (multi-gpu)"))
         scoring.launch_scoring(hydra_cfg, runtime)

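The added comments make the lifecycle contract explicit: the distributed process group outlives any single pipeline step, so neither `convert_compress_model` nor `run_search` may call `runtime.cleanup()`; teardown belongs to whoever created the runtime (the tests below do this with a `with NativeDdpRuntime(...)` block). A minimal sketch of the caller-owned pattern the TODO points toward, using a stand-in `Runtime` class whose body is an assumption for illustration, not ModelOpt's implementation:

```python
class Runtime:
    """Stand-in for NativeDdpRuntime; the real class wraps torch.distributed setup."""

    def __enter__(self):
        print("init process group")  # done once, by the owner
        return self

    def __exit__(self, *exc):
        print("destroy process group")  # cleanup happens here, and only here


def convert_step(runtime: Runtime) -> None:
    # Uses the shared runtime; deliberately does NOT tear it down.
    print("nas convert step")


def search_step(runtime: Runtime) -> None:
    print("nas search step")


with Runtime() as rt:   # the caller owns init/cleanup
    convert_step(rt)    # both steps reuse the same process group
    search_step(rt)
```

Note that the current code still constructs a `NativeDdpRuntime` handle inside each function; the sketch shows the direction the "redesign it" TODO suggests, namely passing the caller-owned runtime in.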
tests/experimental/torch/_compress/compress_test_utils.py

Lines changed: 59 additions & 0 deletions
@@ -19,9 +19,68 @@

 import torch
 from datasets import Dataset, DatasetDict
+from puzzle_tools.hydra_utils import register_hydra_resolvers
 from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase


+def setup_test_model_and_data(
+    project_root_path: Path,
+    tmp_path: Path,
+    rank: int,
+    runtime,
+) -> tuple[
+    Path,
+    Path,
+    Path,
+    Path,
+    str,
+]:
+    """
+    Setup the test model and data for the compress NAS search.
+
+    Args:
+        project_root_path (Path): the root path of the project
+        tmp_path (Path): the temporary path to use for the test
+        rank (int): the rank of the process
+        runtime: the runtime to use for the test
+
+    Returns:
+        tuple[Path, Path, Path, Path, str]:
+            the puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name
+    """
+
+    # Register Hydra custom resolvers (needed for config resolution)
+    register_hydra_resolvers()
+
+    # The inputs for the nas.convert() step.
+    #
+    puzzle_dir = tmp_path
+    llama_checkpoint_path = puzzle_dir / "input_model/llama"
+    dataset_path = puzzle_dir / "dummy_dataset"
+    hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"
+    hydra_config_name = "Llama-3_1-8B"
+
+    if rank == 0:
+        # Setup puzzle_dir and dataset
+        setup_puzzle_dir(puzzle_dir)
+        save_dummy_dataset(dataset_path)
+
+        # Create a small Llama model
+        tokenizer = create_tokenizer(project_root_path)
+        create_and_save_small_llama_model(
+            llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
+        )
+    runtime.wait_for_everyone()
+
+    return (
+        puzzle_dir,
+        llama_checkpoint_path,
+        dataset_path,
+        hydra_config_dir,
+        hydra_config_name,
+    )
+
+
 def create_and_save_small_llama_model(
     output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase
 ):

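Inside the helper, only rank 0 writes shared state (the puzzle dir, the dummy dataset, the small checkpoint), and every rank then blocks on `runtime.wait_for_everyone()` before using it, so no rank can read a half-written fixture. A minimal sketch of the same single-writer idiom with plain `torch.distributed`, assuming `wait_for_everyone()` behaves like a barrier, which is an inference from its name:

```python
import os
import torch.distributed as dist

def setup_shared_fixture(rank: int, path: str) -> None:
    if rank == 0:
        # Single writer: only rank 0 touches the shared filesystem.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "fixture.txt"), "w") as f:
            f.write("ready")
    # Every rank waits here until rank 0 has finished writing.
    dist.barrier()
    # From here on, all ranks can safely read the fixture.
```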
tests/experimental/torch/_compress/nas/plugins/test_nas_convert.py

Lines changed: 25 additions & 59 deletions
@@ -20,13 +20,7 @@

 import torch
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
-from experimental.torch._compress.compress_test_utils import (
-    create_and_save_small_llama_model,
-    create_tokenizer,
-    save_dummy_dataset,
-    setup_puzzle_dir,
-)
-from puzzle_tools.hydra_utils import register_hydra_resolvers
+from experimental.torch._compress.compress_test_utils import setup_test_model_and_data

 import modelopt.torch.nas as mtn
 from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel

@@ -51,7 +45,30 @@ def _test_nas_convert_multiprocess_job(
     with NativeDdpRuntime(
         dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
     ) as runtime:
-        converted_model, puzzle_dir = run_nas_convert(project_root_path, tmp_path, rank, runtime)
+        # Setup the test model and data.
+        puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = (
+            setup_test_model_and_data(project_root_path, tmp_path, rank, runtime)
+        )
+
+        #
+        # Run the mnt.convert() step
+        #
+        input_model = CompressModel()
+        mtn.convert(
+            input_model,
+            mode=[
+                (
+                    "compress",
+                    {
+                        "puzzle_dir": str(puzzle_dir),
+                        "input_model_path": str(llama_checkpoint_path),
+                        "hydra_config_dir": str(hydra_config_dir),
+                        "hydra_config_name": hydra_config_name,
+                        "dataset_path": str(dataset_path),
+                    },
+                )
+            ],
+        )

         #
         # Check assertions

@@ -70,54 +87,3 @@ def _test_nas_convert_multiprocess_job(
         runtime.wait_for_everyone()

     print("PYTEST SUMMARY: test_nas_convert() test has finished successfully")
-
-
-def run_nas_convert(
-    project_root_path: Path,
-    tmp_path: Path,
-    rank: int,
-    runtime,
-):
-    # Register Hydra custom resolvers (needed for config resolution)
-    register_hydra_resolvers()
-
-    # The inputs for the nas.convert() step.
-    #
-    puzzle_dir = tmp_path
-    llama_checkpoint_path = puzzle_dir / "input_model/llama"
-    dataset_path = puzzle_dir / "dummy_dataset"
-    hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"
-    hydra_config_name = "Llama-3_1-8B"
-
-    if rank == 0:
-        # Setup puzzle_dir and dataset
-        setup_puzzle_dir(puzzle_dir)
-        save_dummy_dataset(dataset_path)
-
-        # Create a small Llama model
-        tokenizer = create_tokenizer(project_root_path)
-        create_and_save_small_llama_model(
-            llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
-        )
-    runtime.wait_for_everyone()
-
-    # Run the mnt.convert() step
-    #
-    input_model = CompressModel()
-    converted_model = mtn.convert(
-        input_model,
-        mode=[
-            (
-                "compress",
-                {
-                    "puzzle_dir": str(puzzle_dir),
-                    "input_model_path": str(llama_checkpoint_path),
-                    "hydra_config_dir": str(hydra_config_dir),
-                    "hydra_config_name": hydra_config_name,
-                    "dataset_path": str(dataset_path),
-                },
-            )
-        ],
-    )
-
-    return converted_model, puzzle_dir

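Both refactored tests now spell out the same `("compress", {...})` mode entry from the helper's five return values, with the `Path` objects cast to `str` for the mode config. One obvious follow-up, not part of this commit and with a hypothetical helper name, would be to share that construction too:

```python
from pathlib import Path

def compress_mode_entry(
    puzzle_dir: Path,
    llama_checkpoint_path: Path,
    dataset_path: Path,
    hydra_config_dir: Path,
    hydra_config_name: str,
) -> tuple[str, dict]:
    # Hypothetical helper: cast Paths to str, matching what the tests pass to mtn.convert(mode=[...]).
    return (
        "compress",
        {
            "puzzle_dir": str(puzzle_dir),
            "input_model_path": str(llama_checkpoint_path),
            "hydra_config_dir": str(hydra_config_dir),
            "hydra_config_name": hydra_config_name,
            "dataset_path": str(dataset_path),
        },
    )

# Usage sketch: mtn.convert(CompressModel(), mode=[compress_mode_entry(*setup_test_model_and_data(...))])
```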
tests/experimental/torch/_compress/nas/plugins/test_nas_search.py

Lines changed: 26 additions & 2 deletions
@@ -23,9 +23,10 @@

 import torch
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
-from experimental.torch._compress.nas.plugins.test_nas_convert import run_nas_convert
+from experimental.torch._compress.compress_test_utils import setup_test_model_and_data

 import modelopt.torch.nas as mtn
+from modelopt.torch._compress.nas.plugins.compress_nas_plugin import CompressModel
 from modelopt.torch._compress.runtime import NativeDdpRuntime


@@ -43,7 +44,30 @@ def _test_nas_search_multiprocess_job(
     with NativeDdpRuntime(
         dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
     ) as runtime:
-        converted_model, puzzle_dir = run_nas_convert(project_root_path, tmp_path, rank, runtime)
+        # Setup the test model and data.
+        puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = (
+            setup_test_model_and_data(project_root_path, tmp_path, rank, runtime)
+        )
+
+        #
+        # Run the mnt.convert() step
+        #
+        input_model = CompressModel()
+        converted_model = mtn.convert(
+            input_model,
+            mode=[
+                (
+                    "compress",
+                    {
+                        "puzzle_dir": str(puzzle_dir),
+                        "input_model_path": str(llama_checkpoint_path),
+                        "hydra_config_dir": str(hydra_config_dir),
+                        "hydra_config_name": hydra_config_name,
+                        "dataset_path": str(dataset_path),
+                    },
+                )
+            ],
+        )

         #
         # Run the mnt.search() step

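Design note: this removes the test-to-test dependency (test_nas_search previously imported `run_nas_convert` from the test_nas_convert module), so both tests now rely only on the shared `compress_test_utils` helper. Note the asymmetry: the search test keeps `converted_model` because the following `mnt.search()` step operates on it, while the convert test discards the return value.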
tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@ activation_hooks_kwargs:
   target_layer: "mlp.down_proj"
   layer_input_descriptors_path:

-intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
+intermediate_size_list: [256] # teacher_intermediate_size is 14336
 mlp_init_mode: "PruneByActivationsLog"

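This is the mirror image of the first file in the commit: the realistic four-width list moves into the tutorial example config, while the unit-test config drops to a single tiny width of 256, presumably so the multi-GPU test builds only one small pruned variant and stays fast.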
tests/experimental/torch/_compress/test_compress.py

Lines changed: 7 additions & 39 deletions
@@ -20,13 +20,7 @@

 import torch
 from _test_utils.torch.distributed.utils import spawn_multiprocess_job
-from experimental.torch._compress.compress_test_utils import (
-    create_and_save_small_llama_model,
-    create_tokenizer,
-    save_dummy_dataset,
-    setup_puzzle_dir,
-)
-from puzzle_tools.hydra_utils import register_hydra_resolvers
+from experimental.torch._compress.compress_test_utils import setup_test_model_and_data

 from modelopt.torch._compress import compress
 from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (

@@ -63,42 +57,16 @@ def test_compress(project_root_path: Path, tmp_path: Path):


 def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int):
-    register_hydra_resolvers()
-
-    #
-    # The inputs for the compress() algorihm.
-    #
-    puzzle_dir = tmp_path
-    dataset_path = puzzle_dir / "dummy_dataset"
-    hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"
-    hydra_config_name = "Llama-3_1-8B"
-
     with NativeDdpRuntime(
         dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
     ) as runtime:
-        #
-        # Test setup
-        #
-        if rank == 0:
-            # Setup puzzle_dir and dataset
-            setup_puzzle_dir(puzzle_dir)
-            save_dummy_dataset(dataset_path)
-
-            #
-            # Step 1: Create and save a teacher model to compress
-            # This mimics the normal pipeline where we start with a Llama model
-            #
-
-            # Create a small Llama model (not DeciLM) to match the normal conversion pipeline
-            tokenizer = create_tokenizer(project_root_path)
-            # TODO: change it to "ckpts/llama" once the conversion script is fixed
-            # Currently, the build replacement library step will fail with such a path.
-            llama_checkpoint_path = puzzle_dir / "input_model/llama"
-            create_and_save_small_llama_model(
-                llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
-            )
+        # Setup the test model and data.
+        puzzle_dir, llama_checkpoint_path, dataset_path, hydra_config_dir, hydra_config_name = (
+            setup_test_model_and_data(project_root_path, tmp_path, rank, runtime)
+        )

-        # Use the full conversion pipeline (matches normal usage)
+        # Convert the Llama model to DeciLM model.
+        if rank == 0:
             convert_llama3_to_decilm(
                 input_dir=llama_checkpoint_path,
                 output_dir=puzzle_dir / "ckpts/teacher",

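Design note: `convert_llama3_to_decilm` now runs under `if rank == 0:`, matching the single-writer pattern of `setup_test_model_and_data`, so the teacher checkpoint under `ckpts/teacher` is converted and written exactly once rather than by every rank; presumably a later `runtime.wait_for_everyone()` outside this hunk lets the other ranks proceed once the checkpoint exists.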