Skip to content

Commit dcd5c56

Browse files
committed
resolve conflict
Signed-off-by: fishbell <[email protected]>
1 parent 670d2c5 commit dcd5c56

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,8 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai
815815
GenerationHandle ContinuousBatchingPipeline::EagleDecodingImpl::add_request(
816816
uint64_t request_id,
817817
const ov::Tensor& input_ids,
818-
ov::genai::GenerationConfig sampling_params) {
818+
ov::genai::GenerationConfig sampling_params,
819+
std::optional<ov::Tensor> token_type_ids) {
819820
std::lock_guard<std::mutex> lock(m_draft_generations_mutex);
820821
auto draft_sampling_params = sampling_params;
821822
draft_sampling_params.ignore_eos = true;
@@ -973,7 +974,8 @@ ov::Tensor ContinuousBatchingPipeline::EagleDecodingImpl::create_draft_input_ids
973974
std::vector<EncodedGenerationResult> ContinuousBatchingPipeline::EagleDecodingImpl::generate(
974975
const std::vector<ov::Tensor>& input_ids,
975976
const std::vector<GenerationConfig>& sampling_params,
976-
const StreamerVariant& streamer) {
977+
const StreamerVariant& streamer,
978+
std::optional<std::vector<ov::Tensor>> token_type_ids) {
977979
m_perf_metrics = ov::genai::SDPerModelsPerfMetrics();
978980
m_draft_pipeline->raw_perf_metrics.m_inference_durations = {{ MicroSeconds(0.0f) }};
979981

src/cpp/src/speculative_decoding/speculative_decoding_impl.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP
111111

112112
GenerationHandle add_request(uint64_t request_id,
113113
const ov::Tensor& input_ids,
114-
ov::genai::GenerationConfig sampling_params) override;
114+
ov::genai::GenerationConfig sampling_params,
115+
std::optional<ov::Tensor> token_type_ids = std::nullopt) override;
115116
GenerationHandle add_request(uint64_t request_id,
116117
const std::string& prompt,
117118
ov::genai::GenerationConfig sampling_params) override;
@@ -129,7 +130,8 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP
129130
std::vector<EncodedGenerationResult>
130131
generate(const std::vector<ov::Tensor>& input_ids,
131132
const std::vector<GenerationConfig>& sampling_params,
132-
const StreamerVariant& streamer) override;
133+
const StreamerVariant& streamer,
134+
std::optional<std::vector<ov::Tensor>> token_type_ids = std::nullopt) override;
133135

134136
SpeculativeDecodingMetrics get_speculative_decoding_metrics();
135137
};

0 commit comments

Comments
 (0)