File tree Expand file tree Collapse file tree 2 files changed +8
-4
lines changed
src/cpp/src/speculative_decoding Expand file tree Collapse file tree 2 files changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -815,7 +815,8 @@ ContinuousBatchingPipeline::EagleDecodingImpl::EagleDecodingImpl(const ov::genai
815815GenerationHandle ContinuousBatchingPipeline::EagleDecodingImpl::add_request (
816816 uint64_t request_id,
817817 const ov::Tensor& input_ids,
818- ov::genai::GenerationConfig sampling_params) {
818+ ov::genai::GenerationConfig sampling_params,
819+ std::optional<ov::Tensor> token_type_ids) {
819820 std::lock_guard<std::mutex> lock (m_draft_generations_mutex);
820821 auto draft_sampling_params = sampling_params;
821822 draft_sampling_params.ignore_eos = true ;
@@ -973,7 +974,8 @@ ov::Tensor ContinuousBatchingPipeline::EagleDecodingImpl::create_draft_input_ids
973974std::vector<EncodedGenerationResult> ContinuousBatchingPipeline::EagleDecodingImpl::generate (
974975 const std::vector<ov::Tensor>& input_ids,
975976 const std::vector<GenerationConfig>& sampling_params,
976- const StreamerVariant& streamer) {
977+ const StreamerVariant& streamer,
978+ std::optional<std::vector<ov::Tensor>> token_type_ids) {
977979 m_perf_metrics = ov::genai::SDPerModelsPerfMetrics ();
978980 m_draft_pipeline->raw_perf_metrics .m_inference_durations = {{ MicroSeconds (0 .0f ) }};
979981
Original file line number Diff line number Diff line change @@ -111,7 +111,8 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP
111111
112112 GenerationHandle add_request (uint64_t request_id,
113113 const ov::Tensor& input_ids,
114- ov::genai::GenerationConfig sampling_params) override ;
114+ ov::genai::GenerationConfig sampling_params,
115+ std::optional<ov::Tensor> token_type_ids = std::nullopt ) override ;
115116 GenerationHandle add_request (uint64_t request_id,
116117 const std::string& prompt,
117118 ov::genai::GenerationConfig sampling_params) override ;
@@ -129,7 +130,8 @@ class ContinuousBatchingPipeline::EagleDecodingImpl : public ContinuousBatchingP
129130 std::vector<EncodedGenerationResult>
130131 generate (const std::vector<ov::Tensor>& input_ids,
131132 const std::vector<GenerationConfig>& sampling_params,
132- const StreamerVariant& streamer) override ;
133+ const StreamerVariant& streamer,
134+ std::optional<std::vector<ov::Tensor>> token_type_ids = std::nullopt ) override ;
133135
134136 SpeculativeDecodingMetrics get_speculative_decoding_metrics ();
135137};
You can’t perform that action at this time.
0 commit comments