Draft

44 commits
b28286a
top-k generator for draft model
Jul 7, 2025
28a73fa
updated version
songbell Jul 8, 2025
fd58e61
v2
songbell Jul 8, 2025
6d728ba
support eagle2
songbell Jul 28, 2025
d943850
fix cb issue
songbell Jul 30, 2025
d940b43
opt initial state indexing
songbell Jul 30, 2025
bcc5e2f
handle request finish
songbell Jul 31, 2025
30e9e8b
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
songbell Jul 31, 2025
476af89
fix merge conflict
songbell Jul 31, 2025
5180e37
fix issue
songbell Jul 31, 2025
36e05eb
remove redundant copy
songbell Aug 4, 2025
9ca6539
commit stash content
songbell Aug 5, 2025
89cd543
clear sample interface
songbell Aug 5, 2025
fecbebf
add common transformation for main model
songbell Aug 6, 2025
1553bcd
eagle3 support
songbell Aug 7, 2025
25d8c09
enable perf metrics
songbell Aug 13, 2025
5eaa669
fix Segmentation fault for eagle2
xufang-lisa Aug 13, 2025
a41ea6c
optimize sort function
xufang-lisa Aug 14, 2025
cc22df1
fix no accepted rate for depth>=3
xufang-lisa Aug 15, 2025
23a0d4f
customize parameter
songbell Aug 18, 2025
58bf01b
Merge branch 'bell/eagle2_part1' of https://github.com/songbell/openv…
songbell Aug 18, 2025
ff99e82
eagle tree fully expanding
songbell Aug 20, 2025
0b5e47d
temp solution for eagle3 draft accuracy
songbell Aug 21, 2025
7224e11
enable python test interface
songbell Aug 22, 2025
52862a7
fix windows issue
songbell Aug 22, 2025
49f1377
total token - 1
songbell Aug 27, 2025
a0b37e8
optimize sampler of main model
xufang-lisa Aug 27, 2025
40869fc
optimize sampler of draft model
xufang-lisa Aug 28, 2025
bbe82da
fix candidates sort
xufang-lisa Aug 28, 2025
670d2c5
Merge branch 'master' into bell/eagle2_part1
songbell Sep 2, 2025
dcd5c56
resolve conflict
songbell Sep 2, 2025
4ef9b60
resolve block allocation issue
songbell Sep 5, 2025
1024a6c
add total perf_metric printing for eagle pipeline
xufang-lisa Sep 15, 2025
553d2b9
fix compilation errors
xufang-lisa Sep 15, 2025
0c33718
Merge branch 'master' into bell/eagle2_part1
songbell Sep 15, 2025
d529528
add max_new_tokens setting
xufang-lisa Sep 15, 2025
8a0728d
add c++ sample README
xufang-lisa Sep 15, 2025
03ec40f
align python with c++ tokenizer
songbell Sep 15, 2025
0e3acb9
Merge branch 'bell/eagle2_part1' of https://github.com/songbell/openv…
songbell Sep 15, 2025
5522cc9
add readme
songbell Sep 15, 2025
607aea2
Merge branch 'bell/eagle2_part1' of https://github.com/songbell/openv…
songbell Sep 15, 2025
efe0cb6
Revert "add readme"
songbell Sep 15, 2025
184d1e1
eagle3 python readme
songbell Sep 15, 2025
217475b
add test parameters setting for eagle pipeline
xufang-lisa Sep 15, 2025
5 changes: 3 additions & 2 deletions samples/cpp/text_generation/CMakeLists.txt
@@ -9,7 +9,7 @@ find_package(OpenVINOGenAI REQUIRED
)

function(add_sample_executable target_name)
-add_executable(${target_name} ${target_name}.cpp)
+add_executable(${target_name} ${target_name}.cpp read_prompt_from_file.cpp)
target_link_libraries(${target_name} PRIVATE openvino::genai)
set_target_properties(${target_name} PROPERTIES
# Ensure out-of-box LC_RPATH on macOS with SIP
@@ -29,7 +29,8 @@ set (SAMPLE_LIST
lora_greedy_causal_lm
multinomial_causal_lm
prompt_lookup_decoding_lm
-speculative_decoding_lm)
+speculative_decoding_lm
+eagle_speculative_lm)

foreach(sample IN LISTS SAMPLE_LIST)
add_sample_executable(${sample})
31 changes: 31 additions & 0 deletions samples/cpp/text_generation/README.md
@@ -230,6 +230,37 @@ Recommended models: `meta-llama/Llama-3.2-1B-Instruct`, `meta-llama/Llama-3.2-8B
**Note:**
Structured output enforcement ensures valid JSON formatting, but does not guarantee factual accuracy or meaningfulness. The model may generate plausible-looking JSON with incorrect or nonsensical data (e.g., `{"explanation": "John", "output": 200000}` or `{"final_answer": "AbrakaKadabra9999######4242"}`). For best results, use the latest or fine-tuned models to improve output quality and relevance.

### 6. Eagle Speculative LM (`eagle_speculative_lm`)
- **Description:**
EAGLE is a lossless speculative-decoding algorithm for accelerating LLM inference: a lightweight draft model proposes a tree of candidate tokens, the main model verifies them in a single pass, and only verified tokens are emitted, so the output matches running the main model alone.

- **Convert Model:**
If you have your own draft model, refer to https://jira.devtools.intel.com/browse/CVS-171947 for instructions on converting it.
A set of already converted models is available for download (password: openvino):
```bash
scp -r [email protected]:~/bell/speculative_decoding/eagle3/llama-3.1-8b-instruct-ov-int4/ your_path_to_main/
scp -r [email protected]:~/bell/speculative_decoding/eagle3/EAGLE3-LLaMA3.1-instruct-8B-ov-int4/ your_path_to_draft/
```

- **Run Command:**
Linux:
```bash
source <OpenVINO_install_path>/setupvars.sh
./eagle_speculative_lm <MODEL_DIR> <DRAFT_MODEL_DIR> <MAX_NEW_TOKENS> <DEPTH> "<PROMPT>"
```
Windows:
```bat
<OpenVINO_install_path>\setupvars.bat
eagle_speculative_lm.exe <MODEL_DIR> <DRAFT_MODEL_DIR> <MAX_NEW_TOKENS> <DEPTH> "<PROMPT>"
```
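
For example, assuming the models were downloaded to `your_path_to_main` and `your_path_to_draft` as above (the token budget and depth are illustrative):
```bash
./eagle_speculative_lm your_path_to_main/llama-3.1-8b-instruct-ov-int4 \
    your_path_to_draft/EAGLE3-LLaMA3.1-instruct-8B-ov-int4 \
    128 4 "What is OpenVINO?"
```
`<PROMPT>` may also be a path to a text file, in which case the file's content is used as the prompt.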

- **Benchmark Tools:**
```bash
scp [email protected]:~/xufang/run_eagle_base.py your_path_to_tool
scp -r [email protected]:~/xufang/data your_path_to_test_datasets
python run_eagle_base.py
```

## Troubleshooting

### Unicode characters encoding error on Windows
6 changes: 3 additions & 3 deletions samples/cpp/text_generation/beam_search_causal_lm.cpp
@@ -15,9 +15,9 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config;
config.max_new_tokens = 20;
-config.num_beam_groups = 3;
-config.num_beams = 15;
-config.diversity_penalty = 1.0f;
+config.num_beam_groups = 1;
+config.num_beams = 2;
+//config.diversity_penalty = 1.0f;
config.num_return_sequences = config.num_beams;

auto beams = pipe.generate(prompts, config);
125 changes: 125 additions & 0 deletions samples/cpp/text_generation/eagle_speculative_lm.cpp
@@ -0,0 +1,125 @@
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <openvino/openvino.hpp>

#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/speculative_decoding/perf_metrics.hpp"
#include "read_prompt_from_file.h"

template <typename T>
void print_perf_metrics(T& perf_metrics, const std::string& model_name) {
std::cout << "\n" << model_name << std::endl;
auto generation_duration = perf_metrics.get_generate_duration().mean;
std::cout << " Generate time: " << generation_duration << " ms" << std::endl;
std::cout << " TTFT: " << perf_metrics.get_ttft().mean << " ± " << perf_metrics.get_ttft().std << " ms"
<< std::endl;
std::cout << " TPOT: " << perf_metrics.get_tpot().mean << " ± " << perf_metrics.get_tpot().std << " ms/token"
<< std::endl;
std::cout << " Num generated token: " << perf_metrics.get_num_generated_tokens() << " tokens" << std::endl;
if (model_name == "Total") {
std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_new_token_times.size() << std::endl;
} else {
std::cout << " Total iteration number: " << perf_metrics.raw_metrics.m_durations.size() << std::endl;
}
if (perf_metrics.get_num_input_tokens() > 0) {
std::cout << " Input token size: " << perf_metrics.get_num_input_tokens() << std::endl;
}
}

int main(int argc, char* argv[]) try {
if (6 != argc) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> <EAGLE_MODEL_DIR> <MAX_NEW_TOKENS> <DEPTH> '<PROMPT>'");
}

std::string main_model_path = argv[1];
std::string eagle_model_path = argv[2];
int max_new_tokens = atoi(argv[3]);
int depth = atoi(argv[4]);
std::string prompt = argv[5];
if (std::filesystem::is_regular_file(prompt)) {
std::string prompt_file = prompt;
prompt = utils::read_prompt(prompt_file);
}

// Configure devices - can run main and eagle models on different devices
std::string main_device = "GPU", eagle_device = "GPU"; // currently only GPU is used during developing

// Eagle Speculative settings
ov::genai::GenerationConfig config = ov::genai::greedy();
config.max_new_tokens = max_new_tokens;
// Eagle specific parameters
config.eagle_tree_params.branching_factor = 1; // Number of candidate tokens to consider at each level
config.eagle_tree_params.tree_depth = depth; // How deep to explore the token tree
config.eagle_tree_params.total_tokens = depth + 2; // Total number of tokens to generate in eagle tree
config.num_return_sequences = 1; // only 1 sequence is supported

// Create pipeline with eagle speculative enabled
ov::genai::LLMPipeline pipe(
main_model_path,
main_device,
ov::genai::draft_model(eagle_model_path, eagle_device),
std::pair<std::string, ov::Any>("eagle_mode", ov::Any("EAGLE3")) // Specify eagle3 mode for draft model
);
// Setup performance measurement
auto start_time = std::chrono::high_resolution_clock::now();

// Optional: Create a streaming callback for real-time token display
auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return ov::genai::StreamingStatus::RUNNING;
};

// Run generation with eagle speculative decoding
std::cout << "Generating with Eagle Speculative decoding:" << std::endl;
auto result = pipe.generate(prompt, config, streamer);
std::cout << std::endl;

// Calculate and display performance metrics
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
std::cout << "\nGeneration completed in " << duration.count() << " ms" << std::endl;

auto sd_perf_metrics = std::dynamic_pointer_cast<ov::genai::SDPerModelsPerfMetrics>(result.extended_perf_metrics);
if (sd_perf_metrics) {
print_perf_metrics(result.perf_metrics, "Total");
print_perf_metrics(sd_perf_metrics->main_model_metrics, "MAIN MODEL");
std::cout << " accepted token: " << sd_perf_metrics->get_num_accepted_tokens() << " tokens" << std::endl;
std::cout << " compress rate: "
<< sd_perf_metrics->main_model_metrics.get_num_generated_tokens() * 1.0f /
sd_perf_metrics->main_model_metrics.raw_metrics.m_durations.size()
<< std::endl;
print_perf_metrics(sd_perf_metrics->draft_model_metrics, "DRAFT MODEL");
}
std::cout << std::endl;

// Run without Eagle for comparison.
// Disabled for now: the generation config has no runtime switch to turn Eagle
// off, so the baseline run and its timing are commented out together to avoid
// printing a bogus duration.
/*
std::cout << "\n-----------------------------" << std::endl;
std::cout << "Generating without Eagle Speculative decoding:" << std::endl;

start_time = std::chrono::high_resolution_clock::now();
pipe.generate(prompt, config, streamer);
std::cout << std::endl;

end_time = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
std::cout << "\nStandard generation completed in " << duration.count() << " ms" << std::endl;
*/

} catch (const std::exception& error) {
try {
std::cerr << error.what() << '\n';
} catch (const std::ios_base::failure&) {}
return EXIT_FAILURE;
} catch (...) {
try {
std::cerr << "Non-exception object thrown\n";
} catch (const std::ios_base::failure&) {}
return EXIT_FAILURE;
}
11 changes: 9 additions & 2 deletions samples/cpp/text_generation/greedy_causal_lm.cpp
@@ -10,11 +10,18 @@ int main(int argc, char* argv[]) try {
std::string models_path = argv[1];
std::string prompt = argv[2];
std::string device = "CPU"; // GPU can be used as well

ov::genai::LLMPipeline pipe(models_path, device);
ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
-std::string result = pipe.generate(prompt, config);
+auto start_time = std::chrono::high_resolution_clock::now();
+auto streamer = [](std::string subword) {
+std::cout << subword << std::flush;
+return ov::genai::StreamingStatus::RUNNING;
+};
+std::string result = pipe.generate(prompt, config, streamer);
+auto end_time = std::chrono::high_resolution_clock::now();
+auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
+std::cout << "\nGeneration completed in " << duration.count() << " ms" << std::endl;
std::cout << result << std::endl;
} catch (const std::exception& error) {
try {
6 changes: 3 additions & 3 deletions samples/cpp/text_generation/speculative_decoding_lm.cpp
@@ -11,21 +11,21 @@ int main(int argc, char* argv[]) try {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <MODEL_DIR> <DRAFT_MODEL_DIR> '<PROMPT>'");
}

-ov::genai::GenerationConfig config;
+ov::genai::GenerationConfig config = ov::genai::multinomial();
config.max_new_tokens = 100;
// Speculative decoding generation parameters like `num_assistant_tokens` and `assistant_confidence_threshold` are mutually excluded
// add parameter to enable speculative decoding to generate `num_assistant_tokens` candidates by draft_model per iteration
config.num_assistant_tokens = 5;
// add parameter to enable speculative decoding to generate candidates by draft_model while candidate probability is higher than `assistant_confidence_threshold`
// config.assistant_confidence_threshold = 0.4;

+config.num_return_sequences = 1;
std::string main_model_path = argv[1];
std::string draft_model_path = argv[2];
std::string prompt = argv[3];

// User can run main and draft model on different devices.
// Please, set device for main model in `LLMPipeline` constructor and in `ov::genai::draft_model` for draft.
-std::string main_device = "CPU", draft_device = "CPU";
+std::string main_device = "GPU", draft_device = "GPU";

ov::genai::LLMPipeline pipe(
main_model_path,
src/cpp/include/openvino/genai/continuous_batching_pipeline.hpp
@@ -65,13 +65,17 @@ class OPENVINO_GENAI_EXPORTS ContinuousBatchingPipeline {
class ContinuousBatchingImpl;

class ContinuousBatchingForSpeculativeDecodingImpl;
class ContinuousBatchingForEagleDecodingImpl;
class ContinuousBatchingForPromptLookupImpl;
class SpeculativeDecodingImpl;
class EagleDecodingImpl;
class PromptLookupImpl;

friend class ContinuousBatchingForSpeculativeDecodingImpl;
friend class ContinuousBatchingForEagleDecodingImpl;
friend class ContinuousBatchingForPromptLookupImpl;
friend class SpeculativeDecodingImpl;
friend class EagleDecodingImpl;
friend class PromptLookupImpl;

std::shared_ptr<IContinuousBatchingPipeline> m_impl;
11 changes: 11 additions & 0 deletions src/cpp/include/openvino/genai/generation_config.hpp
@@ -326,6 +326,16 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {
size_t num_assistant_tokens = 0;
size_t max_ngram_size = 0;

// eagle parameters for assisting generation
struct eagle_params {
// eagle/model/cnets.py
// total_tokens = self.total_tokens
// depth = self.depth
// top_k = self.top_k
size_t branching_factor = 1; // top-k
size_t tree_depth = 0; // eagle tree depth (how deep to look ahead); the draft model runs depth + 1 levels (one extra for tree initialization)
size_t total_tokens = 1; // Total number of tokens to generate in eagle tree
} eagle_tree_params;
// Structured output parameters
std::optional<StructuredOutputConfig> structured_output_config;

@@ -346,6 +356,7 @@
bool is_multinomial() const;
bool is_assisting_generation() const;
bool is_prompt_lookup() const;
bool is_eagle_tree() const;
bool is_structured_output_generation() const;

OPENVINO_DEPRECATED("Please, use `is_assisting_generation()` instead of `is_speculative_decoding()`. This method will be removed in 2026.0.0 release")
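
A configuration sketch for these fields (values are illustrative; `greedy()`, `eagle_tree_params`, and `num_return_sequences` all come from the headers in this PR):

```cpp
#include "openvino/genai/generation_config.hpp"

int main() {
    // Minimal sketch: a depth-4 EAGLE token tree. With branching_factor == 1
    // the tree degenerates to a single chain, mirroring the
    // eagle_speculative_lm sample above.
    ov::genai::GenerationConfig config = ov::genai::greedy();
    config.max_new_tokens = 128;
    config.eagle_tree_params.branching_factor = 1; // top-k candidates per level
    config.eagle_tree_params.tree_depth = 4;       // look-ahead depth
    config.eagle_tree_params.total_tokens = 6;     // depth + 2, as in the sample
    config.num_return_sequences = 1;               // only 1 is supported
    return 0;
}
```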
12 changes: 11 additions & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -339,7 +339,17 @@ static constexpr ov::Property<SchedulerConfig> scheduler_config{"scheduler_config"};
static constexpr ov::Property<bool> prompt_lookup{"prompt_lookup"};

/**
-* @brief enable enable_save_ov_model property serves to serialize ov model (xml/bin) generated from gguf model on disk for re-use.
+* @brief The eagle_mode property activates EAGLE speculative decoding (EAGLE2 for now).
+* Pass it when creating the LLMPipeline instance, together with a draft model.
*/
+enum class EagleMode {
+OFF = 0,   // Default mode, no eagle2 optimizations
+EAGLE2 = 1 // Enable eagle2 optimizations
+};
+static constexpr ov::Property<EagleMode> eagle_mode{"eagle_mode"};

+/** @brief The enable_save_ov_model property serves to serialize the ov model (xml/bin) generated from a gguf model to disk for re-use.
* Set `true` to activate this mode.
* And create LLMPipeline instance with this config.
*/
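
A construction sketch using the typed property above (model paths are placeholders; note the eagle_speculative_lm sample instead passes the raw string "EAGLE3", so treat the enum form as an assumption about how the property is consumed):

```cpp
#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Minimal sketch, assuming converted main and EAGLE draft models exist
    // at these placeholder paths (see the sample README above).
    ov::genai::LLMPipeline pipe(
        "path/to/main_model", "GPU",
        ov::genai::draft_model("path/to/eagle_draft_model", "GPU"),
        ov::genai::eagle_mode(ov::genai::EagleMode::EAGLE2));

    ov::genai::GenerationConfig config = ov::genai::greedy();
    config.max_new_tokens = 64;
    std::string result = pipe.generate("What is OpenVINO?", config);
    std::cout << result << std::endl;
    return 0;
}
```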
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/perf_metrics.hpp
@@ -134,7 +134,7 @@ struct OPENVINO_GENAI_EXPORTS PerfMetrics {
MeanStdPair detokenization_duration = {-1.0f, -1.0f};

size_t num_generated_tokens;
-size_t num_input_tokens;
+size_t num_input_tokens = 0;

float get_load_time(); // Load time in ms.
size_t get_num_generated_tokens();
41 changes: 37 additions & 4 deletions src/cpp/src/continuous_batching/block_manager.hpp
@@ -967,7 +967,14 @@ class BlockManager {
continue;
last_block_ids.insert(last_block_id);

-size_t needed_blocks_per_sequence = seq_group->get_num_logical_blocks() - num_physical_blocks;
+// Account in two stages: one ordinary generation step (causal LM) first, then
+// the extra slots for num_tokens_to_validate when speculative decoding is active.
+size_t logical_blocks = seq_group->get_num_logical_blocks();
+size_t logical_blocks_with_1_generation_only = logical_blocks;
+if (seq_group->get_num_tokens_to_validate() > 0) {
+logical_blocks_with_1_generation_only = seq_group->get_num_logical_blocks_for_1_generation();
+}
+
+size_t needed_blocks_per_sequence = logical_blocks_with_1_generation_only - num_physical_blocks;

KVCacheBlock::Ptr last_block = block_table.back();
if (last_block->copy_on_write()) {
@@ -981,11 +988,15 @@
else {
blocks_count += needed_blocks_per_sequence * references_count;
}
-}
-else {
+} else {
// block is used only by one sequence
blocks_count += needed_blocks_per_sequence;
}
+if (seq_group->get_num_tokens_to_validate() > 0) {
+// second stage: charge the extra blocks needed to host the tokens to validate
+size_t needed_blocks_extra = logical_blocks - logical_blocks_with_1_generation_only;
+blocks_count += needed_blocks_extra * last_block->get_references_count();
+}
}
return blocks_count;
}
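// Worked example of the two-stage accounting above (hypothetical numbers,
// assuming a block size of 16 slots): a sequence holding 30 tokens occupies
// ceil(30/16) = 2 physical blocks. One ordinary generation step needs
// ceil(31/16) = 2 logical blocks, so stage one charges nothing extra.
// If 5 draft tokens must also be validated, the sequence spans 36 slots,
// i.e. ceil(36/16) = 3 logical blocks, and stage two charges the one extra
// block once per reference held on the last block.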
@@ -1009,7 +1020,29 @@
}
}
}
+void allocate_slots_for_validation(SequenceGroup::Ptr seq_group) {
+std::lock_guard<std::mutex> lock(m_cached_blocks_map_mutex);
+size_t num_logical_blocks = seq_group->get_num_logical_blocks();
+std::vector<Sequence::Ptr> running_sequences = seq_group->get_running_sequences();
+
+for (size_t i = 0; i < running_sequences.size(); ++i) {
+Sequence::Ptr sequence = running_sequences[i];
+auto seq_id = sequence->get_id();
+size_t num_physical_blocks = 0;
+
+if (m_block_table.find(seq_id) != m_block_table.end()) {
+num_physical_blocks = m_block_table[seq_id][0].size();
+}
+
+if (num_logical_blocks > num_physical_blocks) {
+OPENVINO_ASSERT(can_allocate_blocks(num_logical_blocks - num_physical_blocks));
+allocate(sequence, num_logical_blocks - num_physical_blocks, seq_group->get_prompt_len());
+} else {
+OPENVINO_ASSERT(num_logical_blocks == num_physical_blocks, "A number of physical and logical blocks must be the same in this code path");
+}
+}
+}

/**
* Allocates just enough physical KV cache blocks to a sequence group to be enough for the sequences in it. If the sequences
@@ -1023,7 +1056,7 @@
std::lock_guard<std::mutex> lock(m_cached_blocks_map_mutex);
// Will always allocate the identical number of new blocks (if any) to each of the "layers" to keep the
// number of blocks occupied by each "layer" identical at all times.
-size_t num_logical_blocks = seq_group->get_num_logical_blocks();
+size_t num_logical_blocks = seq_group->get_num_tokens_to_validate() > 0 ? seq_group->get_num_logical_blocks_for_1_generation() : seq_group->get_num_logical_blocks();
std::vector<Sequence::Ptr> running_sequences = seq_group->get_running_sequences();

std::map<size_t, std::list<size_t>> copy_blocks_map;