
# Qualcomm AI Engine Direct - Enable Lookahead Decoding (#11437)

Commit 16ffc96 (1 parent: cf252a8)
## Summary

- Add new eval_mode: lookahead
- Add three arguments: ngram, window, gcap
- Add lhd_token_generator

## Command

```
python3 examples/qualcomm/oss_scripts/llama/llama.py -b build-android --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --prompt "Once" --temperature 0 --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode lookahead --ptq 16a4w -m SM8650 -H ${host} -s ${device} -a ${artifacts} --max_seq_len 4096 --kv_updater smart_mask --prefill_ar_len 64 --ngram 3 --window 2 --gcap 2
```

## Test Results

- QNN SDK: 2.28
- Device: SM8650
- max_seq_len: 4096

### Performance Improvement under different AR-N and different W/G/N

- Llama 3.2 3B: ![image](https://github.com/user-attachments/assets/98365f9f-ccb9-49b0-a4ab-e51b9880efc3)
- Llama 3.2 1B: ![image](https://github.com/user-attachments/assets/c21aba0a-ab2d-4f30-9fbe-cce439bd5f7e)
- Story Llama 110M: ![image](https://github.com/user-attachments/assets/debdf888-5ece-400f-b3ae-892c06ef352a)

### Performance Improvement under different prompts

![image](https://github.com/user-attachments/assets/cd072fa5-1eda-4390-9748-882baab442e0)

## Reference

- https://lmsys.org/blog/2023-11-21-lookahead-decoding/
- https://github.com/hao-ai-lab/LookaheadDecoding/tree/main/lade
- https://github.com/ggml-org/llama.cpp/blob/master/examples/lookahead/lookahead.cpp

cc: @haowhsu-quic

File tree: 13 files changed (+782 / -69 lines)

### examples/qualcomm/oss_scripts/llama/CMakeLists.txt (2 additions, 0 deletions)

```diff
@@ -36,6 +36,8 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
```

### examples/qualcomm/oss_scripts/llama/README.md (15 additions, 3 deletions)

````diff
@@ -4,13 +4,13 @@
 This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models:
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
-3. LLAMA3.2 3B (WIP)
+3. LLAMA3.2 3B
 
 We offer the following modes to execute the model:
 
-KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
+- KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
 
-Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
+- Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
   - AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor in hybrid mode.
   - Prompt processing with AR-N model:
   <figure>
@@ -19,6 +19,7 @@ Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache
   </figcaption>
   </figure>
 
+- Lookahead Mode: Lookahead Mode introduces [lookahead decoding](https://arxiv.org/abs/2402.02057) and uses AR-N model to process prompt to enhance token generation speed. While decoding multiple tokens in a single step is infeasible, an LLM can generate multiple guess tokens in parallel. These guess tokens may fit into future parts of the generated sequence. The lookahead decoder generates and verifies these guess tokens, integrating them into the sequence if suitable. In some cases, it can obtain more than one token in a single step. Result is lossless.
 
 ## Instructions
 ### Note
@@ -127,3 +128,14 @@ You can select the KV Cache update mechanism at runtime by setting the `KV_UPDAT
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
+
+You can choose the lookahead mode to enhance decoding speed. To use this mode, you need to specify the following parameters:
+- `--ngram` (N-gram size): Represents the size of the n-grams used in the lookahead process.
+- `--window` (window size): Determines how many future tokens the algorithm attempts to predict in each step.
+- `--gcap` (Verification candidates): Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.
+
+For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057)
+
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+```
````
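The `--ngram`, `--window`, and `--gcap` knobs map directly onto the guess-and-verify loop described in the Lookahead Mode bullet above. The sketch below is a minimal Python illustration of the verification idea only; it is not the code added in this commit (that lives in `runner/lhd_token_generator.cpp`), and the helper names (`decode_greedy`, `ngram_pool`) are placeholders:

```python
def verify_candidates(decode_greedy, last_token, ngram_pool, gcap):
    """Sketch of the verification branch of lookahead decoding.

    decode_greedy(tokens) stands in for one forward pass that returns the
    greedily decoded next token at every position of `tokens`.
    ngram_pool maps a token to guess n-grams collected by the lookahead
    branch (each guess is a tuple of ngram - 1 tokens).
    Returns the extra tokens accepted on top of the normal next token, so
    the overall output matches ordinary greedy decoding (lossless).
    """
    best = []
    for guess in ngram_pool.get(last_token, [])[:gcap]:   # at most gcap candidates
        preds = decode_greedy([last_token, *guess])
        accepted = []
        expected = preds[0]              # model's own prediction after last_token
        for i, tok in enumerate(guess):
            if tok != expected:          # first mismatch ends this candidate
                break
            accepted.append(tok)
            expected = preds[i + 1]      # next prediction, conditioned on the match
        if len(accepted) > len(best):
            best = accepted
    return best


# Toy check: pretend the model deterministically continues n -> n + 1.
print(verify_candidates(
    lambda toks: [t + 1 for t in toks],       # "model": next token is previous + 1
    last_token=3,
    ngram_pool={3: [(4, 5, 9), (4, 5, 6)]},   # two candidate guesses (ngram - 1 = 3 tokens each)
    gcap=2,
))  # [4, 5, 6]
```

The `window` parameter drives the other half of each step, the lookahead branch that keeps producing fresh guess n-grams; both branches share one forward pass, which is why the exported lookahead graph below sizes its AR length from `(window + gcap) * (ngram - 1)`.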

### examples/qualcomm/oss_scripts/llama/llama.py (69 additions, 6 deletions)
```diff
@@ -11,6 +11,7 @@
 import getpass
 import json
 import logging
+import math
 import os
 import subprocess
 import sys
@@ -90,6 +91,12 @@
 logging.getLogger().setLevel(logging.INFO)
 
 
+def next_power_of_two(n):
+    if n == 0:
+        return 1
+    return 2 ** math.ceil(math.log2(n))
+
+
 def smart_mask_updater(
     ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
```
```diff
@@ -531,6 +538,28 @@ def compile(args, pte_filename, tokenizer):
                 use_i64_token=use_i64_token,
             )
         )
+    elif args.model_mode == "lookahead":
+        llama_instance_list.append(
+            LlamaModel(
+                kv_config,
+                # To get better performance, we round up to the nearest power of 2.
+                ar_len=next_power_of_two(
+                    (args.window + args.gcap) * (args.ngram - 1)
+                ),
+                output_new_cache_only=True,
+                output_cache=True,
+                use_i64_token=use_i64_token,
+            )
+        )
+        llama_instance_list.append(
+            LlamaModel(
+                prefill_config,
+                ar_len=args.prefill_ar_len,
+                output_new_cache_only=True,
+                output_cache=True,
+                use_i64_token=use_i64_token,
+            )
+        )
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
```
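For a concrete feel of the AR length picked here, the rounding can be checked directly with the `next_power_of_two` helper added above (values taken from the new argparse defaults and from the README example command):

```python
import math

def next_power_of_two(n):
    if n == 0:
        return 1
    return 2 ** math.ceil(math.log2(n))

# argparse defaults added below: ngram=5, window=8, gcap=8
print(next_power_of_two((8 + 8) * (5 - 1)))  # 64
# README example command: --ngram 3 --window 2 --gcap 2
print(next_power_of_two((2 + 2) * (3 - 1)))  # 8
```

This is also the quantity that `export_llama` asserts `max_seq_len` must exceed in the hunk further down.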

```diff
@@ -630,8 +659,8 @@ def permute(w, heads):
             tokenizer=tokenizer,
             custom_annotations=custom_annotations,
         )
-        # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
-        if i == 0 and args.model_mode == "hybrid":
+        # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
+        if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
             output_indices = 0
             for node in llama_instance.llama_graph_module.graph.nodes:
                 if node.op == "output":
@@ -673,7 +702,7 @@ def permute(w, heads):
             shared_buffer=args.shared_buffer,
         )
         quant_attrs = llama_instance_list[0].get_quant_attrs()
-    elif args.model_mode == "hybrid":
+    elif args.model_mode in ["hybrid", "lookahead"]:
         sample_inputs_list = [
             llama_instace.inputs for llama_instace in llama_instance_list
         ]
@@ -759,6 +788,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
         eval_mode = 0
     elif args.model_mode == "hybrid":
         eval_mode = 1
+    elif args.model_mode == "lookahead":
+        eval_mode = 2
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
 
@@ -832,6 +863,9 @@ def post_process():
             "--output_path outputs/outputs.txt",
             f"--performance_output_path {performance_output_path}",
             f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
+            f"--window {args.window}",
+            f"--gcap {args.gcap}",
+            f"--ngram {args.ngram}",
             runner_args,
         ]
     )
```
```diff
@@ -971,9 +1005,9 @@ def _build_parser():
 
     parser.add_argument(
         "--model_mode",
-        help="Export and inference kv mode or hybrid mode",
+        help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
         default="kv",
-        choices=["kv", "hybrid"],
+        choices=["kv", "hybrid", "lookahead"],
         type=str,
     )
 
@@ -986,7 +1020,7 @@ def _build_parser():
 
     parser.add_argument(
         "--prefill_ar_len",
-        help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
+        help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead mode.",
         default=32,
         type=int,
     )
@@ -1007,6 +1041,27 @@ def _build_parser():
         help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
     )
 
+    parser.add_argument(
+        "--ngram",
+        help="Represents the size of the n-grams used in the lookahead process.",
+        default=5,
+        type=int,
+    )
+
+    parser.add_argument(
+        "--window",
+        help="Determines how many future tokens the algorithm attempts to predict in each step.",
+        default=8,
+        type=int,
+    )
+
+    parser.add_argument(
+        "--gcap",
+        help="Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.",
+        default=8,
+        type=int,
+    )
+
     parser.add_argument("-v", "--verbose", action="store_true")
 
     return parser
```
```diff
@@ -1023,6 +1078,14 @@ def export_llama(args) -> None:
             args.max_seq_len >= args.prefill_ar_len
         ), "Please ensure max_seq_len is >= prefill_ar_len"
         pte_filename = "hybrid_llama_qnn"
+    elif args.model_mode == "lookahead":
+        assert (
+            args.max_seq_len >= args.prefill_ar_len
+        ), "Please ensure max_seq_len is >= prefill_ar_len"
+        assert args.max_seq_len > next_power_of_two(
+            (args.window + args.gcap) * (args.ngram - 1)
+        ), "Please ensure max_seq_len is > next_power_of_two((args.window + args.gcap) * (args.ngram - 1))"
+        pte_filename = "lookahead_llama_qnn"
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
```

### examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp (19 additions, 4 deletions)
```diff
@@ -53,12 +53,24 @@ DEFINE_int32(
 DEFINE_int32(
     eval_mode,
     0,
-    "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
+    "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
-    "How to update kv cache. Choose between SmartMask and ShiftPointer",
-    "SmartMask");
+    "SmartMask",
+    "How to update kv cache. Choose between SmartMask and ShiftPointer");
 DEFINE_int32(num_iters, 1, "total num of iterations to run.");
+DEFINE_int32(
+    ngram,
+    0,
+    "[Lookahead Decoding] Represents the size of the n-grams used in the lookahead process.");
+DEFINE_int32(
+    window,
+    0,
+    "[Lookahead Decoding] Determines how many future tokens the algorithm attempts to predict in each step.");
+DEFINE_int32(
+    gcap,
+    0,
+    "[Lookahead Decoding] Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.");
 
 std::vector<std::string> CollectPrompts(int argc, char** argv) {
   // Collect all prompts from command line, example usage:
@@ -111,7 +123,10 @@ int main(int argc, char** argv) {
       FLAGS_performance_output_path.c_str(),
       FLAGS_temperature,
       FLAGS_eval_mode,
-      FLAGS_kv_updater);
+      FLAGS_kv_updater,
+      FLAGS_ngram,
+      FLAGS_window,
+      FLAGS_gcap);
   auto llama_version = runner.get_llama_version();
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
```
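Tying the two sides together: `llama.py` maps `--model_mode lookahead` to the integer eval mode and forwards the three new values, which the runner reads through the gflags defined above. The snippet below is only a condensed, illustrative view of that Python-side plumbing (the `args` object is a stand-in; the real runner command contains many more entries than shown here):

```python
from types import SimpleNamespace

# Stand-in for the parsed CLI arguments (values from the README example command).
args = SimpleNamespace(model_mode="lookahead", ngram=3, window=2, gcap=2)

# The model_mode string becomes the runner's integer eval_mode
# (0: kv, 1: hybrid, 2: lookahead), as in the inference() hunk above ...
eval_mode = {"kv": 0, "hybrid": 1, "lookahead": 2}[args.model_mode]

# ... and the three lookahead values are forwarded to the runner as flags,
# matching the DEFINE_int32(ngram/window/gcap) declarations above.
lookahead_flags = [
    f"--window {args.window}",
    f"--gcap {args.gcap}",
    f"--ngram {args.ngram}",
]
print(eval_mode, lookahead_flags)  # 2 ['--window 2', '--gcap 2', '--ngram 3']
```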

### examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp (66 additions, 19 deletions)
```diff
@@ -51,7 +51,7 @@ void KVManager::init_attention_mask(
     int32_t ar_len,
     int32_t n_past) {
   ET_CHECK_MSG(
-      attention_map.size() == ar_len,
+      attention_map.size() <= ar_len,
       "The size of attention_map (%zu) doesn't match with ar_len (%d)",
       attention_map.size(),
       ar_len);
@@ -197,9 +197,11 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
           ? 0
           : metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_);
       v_cache_[layer][head].buffer = single_layer_v_cache +
-          head * single_head_size_in + cache_gap * metadata_.head_dim;
-      v_cache_[layer][head].output_buffer =
-          single_layer_v_cache + (head + 1) * single_head_size_in;
+          head * metadata_.head_dim * metadata_.context_len +
+          cache_gap * metadata_.head_dim;
+      v_cache_[layer][head].output_buffer = single_layer_v_cache +
+          head * metadata_.head_dim * metadata_.context_len +
+          single_head_size_in;
     }
   }
   break;
@@ -311,21 +313,29 @@ bool KVManager::update_cache_tensor(
   return updated;
 }
 
-void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) {
+void KVManager::update_cache(
+    int32_t ar_len,
+    int32_t n_past,
+    int32_t n_update,
+    const std::vector<bool>& selected) {
   ET_CHECK_MSG(
       cur_ar_len_ == ar_len,
       "Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.",
       cur_ar_len_,
       ar_len);
   for (int layer = 0; layer < metadata_.num_layers; ++layer) {
     for (int head = 0; head < metadata_.num_heads; ++head) {
-      update_key(k_cache_[layer][head], n_past, n_update);
-      update_value(v_cache_[layer][head], n_past, n_update);
+      update_key(k_cache_[layer][head], n_past, n_update, selected);
+      update_value(v_cache_[layer][head], n_past, n_update, selected);
     }
   }
 }
 
-void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
+void KVManager::update_key(
+    KVCache& k_cache,
+    int32_t n_past,
+    int32_t n_update,
+    const std::vector<bool>& selected) {
   uint8_t* write_ptr = k_cache.buffer;
   uint8_t* read_ptr = k_cache.output_buffer;
   const int32_t copy_size = n_update * sizeof(uint8_t);
```
```diff
@@ -340,22 +350,35 @@ void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
   write_ptr += iter_size + past_size;
   if (kv_updater_ == KVManagerMode::SMART_MASK)
     write_ptr += past_size;
-
-  for (int i = 0; i < n_iter; ++i) {
-    std::memcpy(write_ptr, read_ptr, copy_size);
-    write_ptr += iter_size;
-    read_ptr += out_size;
+  if (selected.empty()) {
+    for (int i = 0; i < n_iter; ++i) {
+      std::memcpy(write_ptr, read_ptr, copy_size);
+      write_ptr += iter_size;
+      read_ptr += out_size;
+    }
+  } else {
+    std::vector<int32_t> true_indices(n_update);
+    for (int i = 0, j = 0; i < selected.size() && j < n_update; ++i) {
+      if (selected[i]) {
+        true_indices[j++] = i;
+      }
+    }
+    for (int i = 0; i < n_iter; ++i) {
+      auto wp = write_ptr, rp = read_ptr;
+      for (auto ind : true_indices) {
+        *wp++ = rp[ind];
+      }
+      write_ptr += iter_size;
+      read_ptr += out_size;
+    }
   }
 }
 
 void KVManager::update_value(
     KVCache& v_cache,
     int32_t n_past,
-    int32_t n_update) {
-  // Value cache doesn't need to copy for SHIFT_POINTER mode
-  if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
-    return;
-
+    int32_t n_update,
+    const std::vector<bool>& selected) {
   uint8_t* write_ptr = v_cache.buffer;
   uint8_t* read_ptr = v_cache.output_buffer;
   const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t);
@@ -364,7 +387,31 @@ void KVManager::update_value(
   if (kv_updater_ == KVManagerMode::SMART_MASK)
     write_ptr += past_size;
 
-  std::memcpy(write_ptr, read_ptr, copy_size);
+  // Update the value cache for lookahead decoding in SHIFT_POINTER mode
+  if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
+    read_ptr += past_size;
+    write_ptr = read_ptr;
+  }
+
+  if (selected.empty()) {
+    // In general, value cache doesn't need to copy for SHIFT_POINTER mode
+    if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
+      return;
+    std::memcpy(write_ptr, read_ptr, copy_size);
+  } else {
+    int32_t update_times = n_update;
+    auto wp = write_ptr, rp = read_ptr;
+    for (auto sel : selected) {
+      if (sel) {
+        std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t));
+        wp += metadata_.head_dim;
+        update_times--;
+        if (update_times == 0)
+          break;
+      }
+      rp += metadata_.head_dim;
+    }
+  }
 }
 
 } // namespace example
```
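The `selected` vector is how the runner keeps only the verified lookahead positions when writing new K/V entries back into the cache: positions whose guess token was accepted are compacted into contiguous cache slots, and the rest are dropped. The C++ above does this at the byte level with SMART_MASK/SHIFT_POINTER offsets; the Python sketch below only illustrates the same compaction idea on per-position entries (all names are hypothetical):

```python
def compact_selected(cache, new_entries, n_past, selected, n_update):
    """Append only the accepted (verified) positions to the cache.

    cache       -- per-position entries, valid up to index n_past
    new_entries -- one entry per position of the current lookahead step
    selected    -- boolean mask over those positions (True = token accepted)
    n_update    -- number of entries to keep (at most sum(selected))
    """
    write = n_past
    for pos, keep in enumerate(selected):
        if keep:
            cache[write] = new_entries[pos]  # copy accepted positions contiguously
            write += 1
            if write - n_past == n_update:   # stop once n_update entries were kept
                break
    return cache


# Example: of 6 lookahead positions, only positions 0, 2 and 3 were verified.
kv = list("abcd") + [None] * 6
print(compact_selected(kv, list("UVWXYZ"), 4, [True, False, True, True, False, False], 3))
# ['a', 'b', 'c', 'd', 'U', 'W', 'X', None, None, None]
```

When `selected` is empty, both `update_key` and `update_value` fall back to the single contiguous copy they performed before this commit.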