@@ -245,8 +245,8 @@ class AllreduceOp
         , mType(type)
         , mStrategy(strategy)
         , mOp(op)
-        , mEps(eps)
         , mIsMNNVLSupported(false)
+        , mEps(eps)
     {
     }
 
@@ -256,8 +256,8 @@ class AllreduceOp
         , mType(type)
         , mStrategy(strategy)
         , mOp(op)
-        , mEps(eps)
         , mIsMNNVLSupported(false)
+        , mEps(eps)
         , mNcclComm(process_group_)
     {
     }
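Both constructors make the same change: mIsMNNVLSupported is now initialized before mEps. C++ initializes members in declaration order, not in initializer-list order, so a list that disagrees with the declarations only earns a -Wreorder warning; reordering the list makes the code match what actually happens. A minimal sketch, assuming mIsMNNVLSupported is declared before mEps (which the reorder implies):

    // Hypothetical reduction of the pattern above; names mirror the diff.
    class AllreduceOp
    {
    public:
        explicit AllreduceOp(float eps)
            : mIsMNNVLSupported(false) // listed first to match declaration order
            , mEps(eps)
        {
        }

    private:
        bool mIsMNNVLSupported; // declared first, so initialized first
        float mEps;
    };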
@@ -491,6 +491,7 @@ class AllreduceOp
         NCCLCHECK_THROW(ncclCommCount(comm, &nRanks));
         size_t minRegistrationThreshold = static_cast<size_t>(std::max(0.0, a * nRanks + b)) * input.element_size();
         // Disable window registration if neither NVLink nor MNNVL is supported
+        // TODO: replace this check with a communicator query once NCCL 2.29 is available
         if (!mIsNVLINKSupported && !mIsMNNVLSupported)
         {
             minRegistrationThreshold = std::numeric_limits<size_t>::max();
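The threshold comes from a linear model in the rank count, clamped at zero and scaled to bytes; pushing it to SIZE_MAX means no message size ever reaches the window-registration path. A standalone sketch of the gating logic, with illustrative coefficients standing in for the real a and b:

    #include <algorithm>
    #include <cstddef>
    #include <limits>

    // Hypothetical helper; a and b are placeholder model coefficients.
    size_t minRegistrationThresholdBytes(int nRanks, size_t elementSize, bool nvlinkSupported, bool mnnvlSupported)
    {
        double const a = 1024.0;
        double const b = 4096.0;
        size_t threshold = static_cast<size_t>(std::max(0.0, a * nRanks + b)) * elementSize;
        if (!nvlinkSupported && !mnnvlSupported)
        {
            // No NVLink/MNNVL fabric: make the threshold unreachable so
            // window registration is never selected.
            threshold = std::numeric_limits<size_t>::max();
        }
        return threshold;
    }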
@@ -904,8 +905,8 @@ class AllreduceOp
 
         // 2. Check multicast support
         CUdevice cu_device;
-        auto& cuda_driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
-        TLLM_CU_CHECK(cuda_driver->cuDeviceGet(&cu_device, device_id));
+        TLLM_CU_CHECK(cuDeviceGet(&cu_device, device_id));
+        auto cuda_driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
 
         int multicast_supported = 0;
         TLLM_CU_CHECK(cuda_driver->cuDeviceGetAttribute(
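The multicast probe boils down to one driver-API attribute query. A minimal standalone version against the raw CUDA driver API (error handling elided; the production code routes the attribute call through CUDADriverWrapper, and CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED requires CUDA 12.1 or newer):

    #include <cuda.h>
    #include <cstdio>

    int main()
    {
        CUdevice device;
        int multicastSupported = 0;
        cuInit(0);
        cuDeviceGet(&device, 0); // device_id 0 for illustration
        cuDeviceGetAttribute(&multicastSupported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device);
        std::printf("multicast supported: %d\n", multicastSupported);
        return 0;
    }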
@@ -1139,18 +1140,21 @@ class AllreduceOp
             },
             [&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
             {
-                // For ProcessGroup, use allgather
-                // Create a sub-group for the ranks in mGroup
-                std::vector<int> group_ranks(mGroup.begin(), mGroup.end());
-                auto group_pg = torchPg->newGroup(group_ranks);
-                if (group_pg)
+                // For ProcessGroup, use allgather directly.
+                // Note: this assumes the ProcessGroup was created for exactly the ranks in mGroup.
+                std::vector<torch::Tensor> input_tensors
+                    = {torch::tensor({local_mnnvl_status}, torch::kInt32)};
+                std::vector<std::vector<torch::Tensor>> output_tensors(1);
+                // allgather requires pre-allocated outputs, one tensor per rank
+                output_tensors[0].reserve(mGroup.size());
+                for (size_t i = 0; i < mGroup.size(); ++i)
+                {
+                    output_tensors[0].push_back(torch::empty({1}, torch::kInt32));
+                }
+                auto work = torchPg->allgather(output_tensors, input_tensors);
+                if (work)
                 {
-                    std::vector<torch::Tensor> input_tensors
-                        = {torch::tensor({local_mnnvl_status}, torch::kInt32)};
-                    std::vector<std::vector<torch::Tensor>> output_tensors(1);
-                    output_tensors[0].resize(mGroup.size());
-                    group_pg->allgather(output_tensors, input_tensors)->wait();
-
+                    work->wait();
                     for (size_t i = 0; i < mGroup.size(); ++i)
                     {
                         all_mnnvl_status[i] = output_tensors[0][i].item<int>();
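The rewritten branch leans on the c10d allgather contract: output tensors must be pre-allocated, one per rank, with the same shape and dtype as the input. A minimal sketch of the pattern, assuming a ready c10d::ProcessGroup that spans the group (allgatherStatus is a hypothetical helper):

    #include <torch/torch.h>
    #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

    // Gathers one int32 status flag from every rank in the group.
    std::vector<int> allgatherStatus(c10::intrusive_ptr<c10d::ProcessGroup> pg, int localStatus)
    {
        int const worldSize = pg->getSize();
        std::vector<torch::Tensor> inputs = {torch::tensor({localStatus}, torch::kInt32)};
        std::vector<std::vector<torch::Tensor>> outputs(1);
        for (int i = 0; i < worldSize; ++i)
        {
            outputs[0].push_back(torch::empty({1}, torch::kInt32)); // pre-allocate per rank
        }
        pg->allgather(outputs, inputs)->wait();
        std::vector<int> statuses(worldSize);
        for (int i = 0; i < worldSize; ++i)
        {
            statuses[i] = outputs[0][i].item<int>();
        }
        return statuses;
    }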