-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -45,17 +45,42 @@ class EnsembleContext;
 
 using IterationCount = size_t;
 
+// Check if the model is configured to preserve the order of responses.
+// This is critical for async execution of ResponseComplete callbacks.
+inline bool
+preserve_responses_order(const inference::ModelConfig& config)
+{
+  uint64_t total_instance_groups = 0;
+  for (const auto& group : config.instance_group()) {
+    total_instance_groups += group.count();
+  }
+
+  // Case 1: Sequence batching is enabled
+  // Case 2: Dynamic batching is disabled and there is only one instance group
+  // Case 3: Dynamic batching is enabled and preserve_ordering is true
+  // Case 4: Model transaction policy is decoupled (if the final response
+  // callback is not executed in the last step, the RequestTracker object will
+  // be freed prematurely, leading to a segmentation fault)
+  return config.has_sequence_batching() ||
+         (!config.has_dynamic_batching() && total_instance_groups <= 1) ||
+         (config.has_dynamic_batching() &&
+          config.dynamic_batching().preserve_ordering()) ||
+         config.model_transaction_policy().decoupled();
+}
+
 // Request tracker is passed as 'userp' in RequestRelease function and used
 // to manage the lifecycle of the ensemble request
 class RequestTracker {
  public:
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }
 
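The preserve_responses_order helper added in the hunk above is a pure function of the model config, so its four cases are easy to exercise in isolation. The following sketch is not part of the commit: it mirrors the same decision logic with a plain struct standing in for the inference::ModelConfig protobuf, and all struct and field names here are illustrative assumptions, not the real accessors.

// Illustrative stand-in for inference::ModelConfig; names are assumptions
// made for this sketch only.
#include <cstdint>
#include <iostream>
#include <vector>

struct SketchModelConfig {
  bool has_sequence_batching = false;
  bool has_dynamic_batching = false;
  bool preserve_ordering = false;  // dynamic_batching().preserve_ordering()
  bool decoupled = false;          // model_transaction_policy().decoupled()
  std::vector<uint64_t> instance_group_counts;
};

// Mirrors the four cases listed in the diff above.
bool
PreserveResponsesOrder(const SketchModelConfig& config)
{
  uint64_t total_instance_groups = 0;
  for (const auto count : config.instance_group_counts) {
    total_instance_groups += count;
  }
  return config.has_sequence_batching ||
         (!config.has_dynamic_batching && total_instance_groups <= 1) ||
         (config.has_dynamic_batching && config.preserve_ordering) ||
         config.decoupled;
}

int
main()
{
  // Two instances, no batching constraints: responses may complete out of
  // order, so the async callback path is allowed.
  SketchModelConfig two_instances;
  two_instances.instance_group_counts = {2};
  std::cout << PreserveResponsesOrder(two_instances) << "\n";  // prints 0

  // Decoupled transaction policy: order must be preserved.
  SketchModelConfig decoupled;
  decoupled.decoupled = true;
  std::cout << PreserveResponsesOrder(decoupled) << "\n";  // prints 1
  return 0;
}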
@@ -70,6 +95,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +147,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 // Step is used as 'userp' and keeps ensemble context alive
@@ -129,9 +157,9 @@ class RequestTracker {
 struct Step {
   Step(
       size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
-      uint32_t flags)
+      uint32_t flags, bool preserve_responses_order)
       : correlation_id_(correlation_id), flags_(flags), response_flags_(0),
-        step_idx_(step_idx)
+        preserve_responses_order_(preserve_responses_order), step_idx_(step_idx)
   {
   }
 
@@ -154,7 +182,7 @@ struct Step {
   // returning from the callback.
   uint32_t response_flags_;
   TRITONSERVER_InferenceResponse* response_;
-
+  const bool preserve_responses_order_;
 
   size_t step_idx_;
 };
@@ -237,7 +265,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
 
   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +354,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -375,20 +405,26 @@ class EnsembleContext {
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);
 
   auto& lrequest = request_tracker_->Request();
 
@@ -603,29 +639,57 @@ void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
     }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue is
+  // at capacity, execute the callback immediately in the current thread.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue is
+  // at capacity, execute the callback immediately in the current thread.
+  // Note: The async callback optimization does not guarantee the order of
+  // responses and exploits cases where responses can be out of order. For
+  // models that are required to preserve the order of responses, the response
+  // callbacks must be executed synchronously in the same thread.
+  if (!step_raw_ptr->preserve_responses_order_ &&
+      pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
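Both callbacks above follow the same dispatch rule: hand the work to the shared callback pool only while the task queue is shorter than the worker count, and otherwise (or when response order must be preserved) run the callback inline on the calling thread. The sketch below is a hypothetical, self-contained illustration of that rule; MiniPool and Dispatch are made-up names that only imitate the TaskQueueSize()/Size()/Enqueue() interface used in the diff, not the actual triton::common::ThreadPool implementation.

// Minimal sketch of the "enqueue if idle enough, else run inline" pattern.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class MiniPool {
 public:
  explicit MiniPool(size_t n)
  {
    for (size_t i = 0; i < n; ++i) {
      workers_.emplace_back([this] { Work(); });
    }
  }
  ~MiniPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      done_ = true;
    }
    cv_.notify_all();
    for (auto& w : workers_) w.join();
  }
  size_t Size() const { return workers_.size(); }
  size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }
  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

 private:
  void Work()
  {
    for (;;) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> lk(mtx_);
        cv_.wait(lk, [this] { return done_ || !tasks_.empty(); });
        if (done_ && tasks_.empty()) return;
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();
    }
  }
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool done_ = false;
};

// Same shape as the callbacks above: offload when the pool looks idle enough,
// fall back to synchronous execution otherwise.
void
Dispatch(MiniPool* pool, std::function<void()> fn, bool preserve_order)
{
  if (!preserve_order && pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();
  }
}

int
main()
{
  MiniPool pool(2);
  Dispatch(&pool, [] { std::cout << "async callback\n"; }, false);
  Dispatch(&pool, [] { std::cout << "inline callback\n"; }, true);
  return 0;  // MiniPool's destructor drains the queue before exiting.
}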
@@ -971,8 +1035,8 @@ EnsembleContext::InitStep(
   for (const auto& pair : istep.output_to_tensor_) {
     irequest->AddOriginalRequestedOutput(pair.first);
   }
-
-  step->reset(new Step(step_idx, correlation_id, flags));
+  const bool preserve_order = preserve_responses_order(model->Config());
+  step->reset(new Step(step_idx, correlation_id, flags, preserve_order));
 
   irequest->SetId(request_id_);
   irequest->SetCorrelationId(correlation_id);
@@ -1448,7 +1512,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1601,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }
 
 EnsembleScheduler::~EnsembleScheduler()