
Commit 50a580c

danielkorzekwa, kevalmorabia97, and LianaMikael authored
Compress tutorial (PoC) (#492)
## What does this PR do?

Compress tutorial (PoC) + compress cli app.

Signed-off-by: Daniel Korzekwa <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Signed-off-by: Liana Mikaelyan <[email protected]>
Signed-off-by: Liana Mikaelyan <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
Co-authored-by: Liana Mikaelyan <[email protected]>
Co-authored-by: Liana Mikaelyan <[email protected]>
1 parent f7d547f commit 50a580c

File tree

18 files changed: +879 -10 lines changed

examples/compress/README.md

Lines changed: 200 additions & 0 deletions

# Compress Algorithm Tutorial

This tutorial demonstrates how to compress large language models using the Compress algorithm based on the [Puzzle paper](https://arxiv.org/abs/2411.19146).
The goal of the algorithm is to find the optimal modifications to the MLP and attention layers of the model, resulting in a heterogeneous model architecture.
The supported modifications are:

- `ffn_intermediate_size`: different FFN intermediate sizes
- `attention op/noop`: complete removal of attention layers

To use the Puzzle algorithm effectively, we need to specify the target number of parameters and/or the target memory. The final stage uses a Mixed-Integer Programming (MIP) algorithm to find the optimal combination of layer modifications that satisfies the target requirements, as sketched below.
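
To make the final selection step concrete, here is a minimal, illustrative sketch of the kind of optimization the MIP stage performs. The per-block scores, memory costs, and brute-force search below are hypothetical stand-ins for the toolkit's actual measurements and solver:

```python
# Illustrative only: choose one variant per block so that total accuracy is
# maximized subject to a memory budget. All numbers are made up.
from itertools import product

# Per-block candidates as (accuracy score, memory cost in MiB), e.g. full
# attention, pruned FFN, or attention no-op variants.
blocks = [
    [(0.95, 3200), (0.93, 2400), (0.90, 1600)],  # block 0 variants
    [(0.96, 3200), (0.92, 2400), (0.88, 1600)],  # block 1 variants
]
target_memory_mib = 5000

# On this toy instance, exhaustive search stands in for the MIP solver.
best = None
for choice in product(*blocks):
    memory = sum(cost for _, cost in choice)
    score = sum(acc for acc, _ in choice)
    if memory <= target_memory_mib and (best is None or score > best[0]):
        best = (score, memory, choice)

print(best)  # the highest-scoring combination that fits the budget
```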
12+
13+
In this example, we compress the [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric.
14+
15+
## Environment
16+
17+
- Install TensorRT-Model-Optimizer in editable mode with the corresponding dependencies:
18+
19+
```bash
20+
pip install -e .[hf,compress]
21+
```
22+
23+
- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU.

## Compress the Model

1. Specify the `puzzle_dir`, `input_hf_model_path`, `dataset_path`, `intermediate_size_list`, and `target_memory` arguments in the [llama-3_1-8B_pruneffn_memory.yaml](./configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml) configuration file.

   **_NOTE:_** How to choose `intermediate_size_list`? The list specifies the candidate FFN sizes that we wish to search over. It is recommended to choose several pruning sizes (e.g., 15%, 20%, 30% of the original). Note that the values must be hardware-friendly (divisible by 256) to avoid issues with tensor operations in subsequent steps; see the sketch below.

   Let's first aim for a 32% GPU memory reduction by setting `target_memory: 78_000` MiB. This means that the algorithm will choose the candidates with the highest accuracy that also meet the specified requirement.
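
   A minimal sketch (not part of the toolkit) of one way to derive hardware-friendly candidates; the retention fractions here are arbitrary examples:

   ```python
   # Round fractions of the teacher FFN intermediate size to the nearest
   # multiple of 256 so every candidate stays hardware-friendly.
   teacher_intermediate_size = 14336  # Llama-3.1-8B FFN intermediate size

   fractions = [0.2, 0.4, 0.6, 0.8]  # example retention ratios to search over
   candidates = sorted(round(teacher_intermediate_size * f / 256) * 256 for f in fractions)
   print(candidates)  # [2816, 5632, 8704, 11520] -- each divisible by 256
   ```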

2. Download and prepare the [Nemotron-Post-Training-Dataset-v2](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2).

   Dataset splits: "code", "math", "stem", "chat", excluding reasoning samples (2.62 GB).

   ```bash
   python -m modelopt.torch._compress.dataset.prepare_dataset --dataset_name nvidia/Nemotron-Post-Training-Dataset-v2 --output_dir path/to/Nemotron-Post-Training-Dataset-v2
   ```

3. Run the compression script.

   ```bash
   torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Compress Progress"
   ```

   This will save the full output to `log.txt` and display the following progress on screen:

   ```bash
   [2025-11-02 12:06:34][rank-0][main.py:71] Compress Progress 1/8: starting compression pipeline
   [2025-11-02 12:06:45][rank-0][compress_nas_plugin.py:123] Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)
   [2025-11-02 12:07:07][rank-0][compress_nas_plugin.py:132] Compress Progress 3/8: scoring pruning activations (multi-gpu)
   [2025-11-02 12:11:36][rank-0][compress_nas_plugin.py:137] Compress Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)
   [2025-11-02 12:12:20][rank-0][compress_nas_plugin.py:217] Compress Progress 5/8: building replacement library and subblock statistics (single-gpu)
   [2025-11-02 12:12:21][rank-0][compress_nas_plugin.py:222] Compress Progress 6/8: calculating one block scores (multi-gpu)
   [2025-11-02 12:50:41][rank-0][compress_nas_plugin.py:226] Compress Progress 7/8: running MIP and realizing models (multi-gpu)
   [2025-11-02 12:52:34][rank-0][main.py:115] Compress Progress 8/8: compression pipeline completed (multi-gpu)
   ```

Once the process is complete, the resulting network architecture will be recorded in `log.txt` for your review:

```bash
...
block_0: attention gqa_4 ffn intermediate_14336
block_1: attention gqa_4 ffn intermediate_14336
block_2: attention gqa_4 ffn intermediate_14336
block_3: attention gqa_4 ffn intermediate_14336
block_4: attention gqa_4 ffn intermediate_14336
block_5: attention gqa_4 ffn intermediate_14336
block_6: attention gqa_4 ffn intermediate_14336
block_7: attention gqa_4 ffn intermediate_14336
block_8: attention gqa_4 ffn intermediate_14336
block_9: attention gqa_4 ffn intermediate_14336
block_10: attention gqa_4 ffn intermediate_14336
block_11: attention gqa_4 ffn intermediate_14336
block_12: attention gqa_4 ffn intermediate_14336
block_13: attention gqa_4 ffn intermediate_14336
block_14: attention gqa_4 ffn intermediate_14336
block_15: attention gqa_4 ffn intermediate_14336
block_16: attention gqa_4 ffn intermediate_14336
block_17: attention no_op ffn intermediate_14336
block_18: attention no_op ffn intermediate_14336
block_19: attention no_op ffn intermediate_14336
block_20: attention no_op ffn intermediate_14336
block_21: attention no_op ffn intermediate_14336
block_22: attention no_op ffn intermediate_14336
block_23: attention no_op ffn intermediate_14336
block_24: attention no_op ffn intermediate_14336
block_25: attention no_op ffn intermediate_14336
block_26: attention no_op ffn intermediate_14336
block_27: attention no_op ffn intermediate_14336
block_28: attention no_op ffn intermediate_14336
block_29: attention gqa_4 ffn intermediate_14336
block_30: attention gqa_4 ffn intermediate_14336
block_31: attention gqa_4 ffn intermediate_14336

[2025-11-02 04:53:11,332][rank-0][run_puzzle.py:295] Total costs: {'stats.memory_mib': 75796.4140625, 'stats.ffn_num_params': 5637275648, 'stats.num_kv_heads': 160, 'stats.kv_cache_memory_mib': 61440.0, 'stats.ffn_memory_mib': 10752.25, 'stats.attention_memory_mib': 63040.15625, 'stats.attention_num_params': 838942720, 'stats.num_params': 7526895616, 'stats.has_attention': 20, 'stats.has_ffn': 32}
...
################################################################
validate_model_and_extract_token_probs(model_name='teacher')
################################################################
...
Average losses = {'lm_loss': 1.118250765837729, 'token_accuracy_top_1': 0.7331905364990234, 'token_accuracy_top_5': 0.9094219207763672, 'token_accuracy_top_10': 0.9423646926879883}
...
################################################################
validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True)
################################################################
...
Average losses = {'lm_loss': 1.7577573340386152, 'token_accuracy_top_1': 0.6225490570068359, 'token_accuracy_top_5': 0.846257209777832, 'token_accuracy_top_10': 0.8987817764282227}
```

A 30% GPU memory reduction leads to nearly a 5% regression in the token_accuracy_top_10 metric (0.898 vs. 0.942), as the quick check below confirms. Let's rerun the MIP search aiming for a 15% memory reduction.
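
The check, using the token_accuracy_top_10 values logged for the teacher and for solution_0:

```python
# Quick arithmetic check of the regression quoted above, from the logs.
teacher_top10 = 0.9423646926879883   # teacher validation
solution_top10 = 0.8987817764282227  # solution_0 at target_memory: 78_000

print(f"{1 - solution_top10 / teacher_top10:.1%}")  # -> 4.6%
```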

## Re-run MIP Search with different constraints

If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag.
This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation have already been completed.

For example, let's set `target_memory: 96_000` in `llama-3_1-8B_pruneffn_memory.yaml`.

```bash
torchrun --nproc_per_node 2 examples/compress/main.py --config path/to/llama-3_1-8B_pruneffn_memory.yaml --mip-only 2>&1 | tee ./log.txt | grep "Compress Progress"
```

This will generate the following network architecture (see `log.txt`):

```bash
block_0: attention gqa_4 ffn intermediate_14336
block_1: attention gqa_4 ffn intermediate_14336
block_2: attention gqa_4 ffn intermediate_14336
block_3: attention gqa_4 ffn intermediate_14336
block_4: attention gqa_4 ffn intermediate_14336
block_5: attention gqa_4 ffn intermediate_14336
block_6: attention gqa_4 ffn intermediate_14336
block_7: attention gqa_4 ffn intermediate_14336
block_8: attention gqa_4 ffn intermediate_14336
block_9: attention gqa_4 ffn intermediate_14336
block_10: attention gqa_4 ffn intermediate_14336
block_11: attention gqa_4 ffn intermediate_14336
block_12: attention gqa_4 ffn intermediate_14336
block_13: attention gqa_4 ffn intermediate_14336
block_14: attention gqa_4 ffn intermediate_14336
block_15: attention gqa_4 ffn intermediate_14336
block_16: attention gqa_4 ffn intermediate_14336
block_17: attention gqa_4 ffn intermediate_14336
block_18: attention no_op ffn intermediate_14336
block_19: attention no_op ffn intermediate_14336
block_20: attention no_op ffn intermediate_14336
block_21: attention gqa_4 ffn intermediate_14336
block_22: attention no_op ffn intermediate_14336
block_23: attention no_op ffn intermediate_14336
block_24: attention no_op ffn intermediate_14336
block_25: attention gqa_4 ffn intermediate_14336
block_26: attention gqa_4 ffn intermediate_14336
block_27: attention gqa_4 ffn intermediate_14336
block_28: attention gqa_4 ffn intermediate_14336
block_29: attention gqa_4 ffn intermediate_14336
block_30: attention gqa_4 ffn intermediate_14336
block_31: attention gqa_4 ffn intermediate_14336

[2025-11-02 12:50:42,024][rank-0][run_puzzle.py:295] Total costs: {'stats.memory_mib': 94708.4609375, 'stats.has_ffn': 32, 'stats.ffn_memory_mib': 10752.25, 'stats.kv_cache_memory_mib': 79872.0, 'stats.attention_num_params': 1090625536, 'stats.ffn_num_params': 5637275648, 'stats.has_attention': 26, 'stats.num_params': 7778578432, 'stats.attention_memory_mib': 81952.203125, 'stats.num_kv_heads': 208}
...
################################################################
validate_model_with_kl_div(model_name='solution_0', is_calc_kl_div=True)
################################################################
Average losses = {'lm_loss': 1.2425934937782586, 'token_accuracy_top_1': 0.703862190246582, 'token_accuracy_top_5': 0.8954982757568359, 'token_accuracy_top_10': 0.9336576461791992}
```
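
Relative to the teacher metrics shown earlier, the token_accuracy_top_10 regression is now under 1% (0.9337 vs. 0.9424), consistent with the figure quoted in the introduction.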

On the other hand, if you set `target_memory: 28_000`, you'll observe that the intermediate FFN sizes are significantly reduced in certain layers (see `log.txt` for details):

```bash
block_5: attention no_op ffn intermediate_11520
block_6: attention no_op ffn intermediate_14336
block_7: attention no_op ffn intermediate_8704
block_8: attention no_op ffn intermediate_14336
block_9: attention no_op ffn intermediate_3072
block_10: attention no_op ffn intermediate_11520
block_11: attention no_op ffn intermediate_11520
block_12: attention no_op ffn intermediate_11520
block_13: attention no_op ffn intermediate_11520
block_14: attention no_op ffn intermediate_3072
```

## Evaluation

Once the model is ready, you can evaluate it using the [Language Model Evaluation Harness](https://pypi.org/project/lm-eval/). For example, run the following to evaluate the model on the [Massive Multitask Language Understanding](https://huggingface.co/datasets/cais/mmlu) benchmark:

```bash
lm_eval --model hf \
    --model_args pretrained=path/to/model,dtype=bfloat16,trust_remote_code=true,parallelize=True \
    --tasks mmlu \
    --num_fewshot 5 \
    --batch_size 4
```

## Advanced usage

Modify the `path/to/Llama-3_1-8B.yaml` file for advanced compression scenarios.
Lines changed: 110 additions & 0 deletions

defaults:
  - pruning: ffn_pruning
  - scoring: ../validate_solutions_defaults
  - realize_model: ../validate_solutions_defaults
  - bypass:
  - override hydra/hydra_logging: disabled
  - _self_

puzzle_dir: ???
teacher_dir: ${puzzle_dir}/ckpts/teacher/
replacement_library_path: ${puzzle_dir}/replacement_library.json
dataset_path: ??? # path to v0.4_mini

skip_realize_model: false

build_replacement_library:
  add_ffn_no_ops: true
  add_attention_no_ops: true

calc_subblock_stats:
  batch_sizes: [64, 96, 128]
  prefill_seq_len: 4096
  generation_seq_len: 4096
  num_active_tokens_override: # Optional override for sequence lengths
  prefill_queue_size: 0
  allocate_prefill_query: false
  benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
  merge_with_existing_stats: false
  subblock_stats_filename: "subblock_stats.json"
  moe_stats_filename: "moe_stats.json"
  runtime_stats:
    backend: trt_torch

scoring:
  solutions_to_validate:
  skip_existing_solutions: true

  replacement_library_path: ${replacement_library_path}
  solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
  teacher_dir: ${to_path:${teacher_dir}}
  output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

  eval_samples: 10 # default is 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

mip:
  single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
  subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
  output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
  gathered_metrics_path:
  puzzle_profile:

  # puzzle_profile:
  objective: metrics.cosine_embedding_loss_hidden_states
  bigger_is_better: false
  num_solutions: 1
  minimal_diversity: 2

  subblock_stats_args:
    - batch_size: 96
      weights_dtype: torch.bfloat16
      activations_dtype: torch.bfloat16
      kv_cache_dtype: torch.bfloat16

  report_additional_costs:
    - stats.memory_mib
    - stats.num_params
    - stats.num_kv_heads
    - stats.has_attention
    - stats.has_ffn
    - stats.kv_cache_memory_mib
    - stats.attention_memory_mib
    - stats.ffn_memory_mib
    - stats.ffn_num_params
    - stats.attention_num_params

  human_constraints:
    target_memory: 78_000

  mip_constraints:
    use_greedy_search: false
    is_multi_layer_puzzle: true
    metric_overrides:
    constrain_search_func:
    max_seconds_per_solution: 60

realize_model:
  teacher_dir: ${to_path:${teacher_dir}}
  tokenizer_name: ${to_path:${teacher_dir}}
  replacement_library_path: ${replacement_library_path}
  save_models: true
  solutions_path: # Filled dynamically

  # Validate params
  skip_validation: false # To enable validation of the model solution set `skip_validation` as False
  eval_samples: 128
  micro_batch_size: 1
  seed: 42
  shuffle_seed: 444
  dataset_path: ${dataset_path}

nccl_timeout_minutes: ${timedelta_minutes:10}

# This section redirects Hydra outputs
hydra:
  run:
    dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
Lines changed: 21 additions & 0 deletions

defaults:
  - Llama-3_1-8B
  - _self_

# Input Hugging Face model to compress
input_hf_model_path: /workspace/hf_models/meta-llama/Llama-3.1-8B-Instruct

# Dataset path for pruning and NAS scoring
dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2

# Working directory for compression outputs
puzzle_dir: /workspace/puzzle_dir

# MIP memory constraint (in MiB)
mip:
  human_constraints:
    target_memory: 96_000 # 96 GiB

# FFN intermediate sizes to search over (heterogeneous architecture)
pruning:
  intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
Lines changed: 16 additions & 0 deletions

defaults:
  - pruning_defaults

activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}

activation_hooks_kwargs:
  method: independent_kv_head_contribution
  optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory
  target_layer: "self_attn.o_proj"
  layer_input_descriptors_path:

# n_heads_in_group: 4
# num_attention_heads: 32 # num query heads
# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group
n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1]
gqa_init_mode: "PruneKVHeads"
Lines changed: 12 additions & 0 deletions

defaults:
  - pruning_defaults

activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}

activation_hooks_kwargs:
  method: iterative
  target_layer: "mlp.down_proj"
  layer_input_descriptors_path:

intermediate_size_list: [3072, 5888, 8704, 11520] # teacher_intermediate_size is 14336
mlp_init_mode: "PruneByActivationsLog"
Lines changed: 15 additions & 0 deletions

defaults:
  - pruning_defaults

activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id}

activation_hooks_kwargs:
  method: layer_norm_contribution
  target_layer: "layernorm"

# Hidden dimension pruning specific settings
hidden_size_list: [3072, 2048] # Target hidden sizes to prune to
hidden_size_init_mode: "PruneByChannelRanking"
mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher
gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher
linear_init_mode: "FromTeacher"
