@@ -242,21 +242,25 @@ class AllreduceOp
242242 AllreduceOp (
243243 std::set<int > group, nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
244244 : mGroup (std::move(group))
245+ , mIsNVLINKSupported (false )
246+ , mIsP2PSupported (false )
247+ , mIsMNNVLSupported (false )
245248 , mType (type)
246249 , mStrategy (strategy)
247250 , mOp (op)
248- , mIsMNNVLSupported (false )
249251 , mEps (eps)
250252 {
251253 }
252254
253255 AllreduceOp (std::set<int > group, c10::intrusive_ptr<c10d::ProcessGroup> const & process_group_,
254256 nvinfer1::DataType type, AllReduceStrategyType strategy, AllReduceFusionOp op, float eps)
255257 : mGroup (std::move(group))
258+ , mIsNVLINKSupported (false )
259+ , mIsP2PSupported (false )
260+ , mIsMNNVLSupported (false )
256261 , mType (type)
257262 , mStrategy (strategy)
258263 , mOp (op)
259- , mIsMNNVLSupported (false )
260264 , mEps (eps)
261265 , mNcclComm (process_group_)
262266 {
@@ -1109,20 +1113,9 @@ class AllreduceOp
11091113 std::visit (overloaded{[&](std::shared_ptr<ncclComm_t>& comm_ptr)
11101114 {
11111115 // For NCCL comm, use MPI to gather status
1112- // Map group ranks to positions
1113- std::vector<int > group_ranks (mGroup .begin (), mGroup .end ());
1114- int my_group_pos = 0 ;
1115- for (size_t i = 0 ; i < group_ranks.size (); ++i)
1116- {
1117- if (group_ranks[i] == rank)
1118- {
1119- my_group_pos = i;
1120- break ;
1121- }
1122- }
1123-
11241116 // Use MPI allgather to collect MNNVL status
11251117 // Create a sub-communicator for the group
1118+ std::vector<int > group_ranks (mGroup .begin (), mGroup .end ());
11261119 MPI_Group world_group, new_group;
11271120 MPI_Comm group_comm;
11281121 MPI_Comm_group (COMM_SESSION, &world_group);
0 commit comments