Commit 6a85c8e

fixes
Signed-off-by: Ludwig Schneider <[email protected]>
1 parent aca32d7 commit 6a85c8e

1 file changed: 14 additions, 15 deletions

cpp/tensorrt_llm/thop/allreduceOp.cpp

@@ -245,8 +245,8 @@ class AllreduceOp
         , mType(type)
         , mStrategy(strategy)
         , mOp(op)
-        , mEps(eps)
         , mIsMNNVLSupported(false)
+        , mEps(eps)
     {
     }

@@ -256,8 +256,8 @@ class AllreduceOp
         , mType(type)
         , mStrategy(strategy)
         , mOp(op)
-        , mEps(eps)
         , mIsMNNVLSupported(false)
+        , mEps(eps)
         , mNcclComm(process_group_)
     {
     }
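
Both constructor hunks apply the same reordering. In C++, non-static data members are initialized in the order they are declared in the class, not in the order they appear in the mem-initializer list, and compilers flag mismatches with -Wreorder; moving mEps after mIsMNNVLSupported presumably matches the declaration order of the two members. A minimal, self-contained sketch of the rule (illustrative names, not the real AllreduceOp):

    // Members are constructed in declaration order; listing them in the
    // same order in the mem-initializer list avoids -Wreorder warnings
    // (and bugs when one initializer reads another member).
    struct Example
    {
        Example(float eps)
            : mSupported(false) // declared first, initialized first
            , mEps(eps)         // declared second, initialized second
        {
        }

        bool mSupported;
        float mEps;
    };
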
@@ -491,6 +491,7 @@ class AllreduceOp
     NCCLCHECK_THROW(ncclCommCount(comm, &nRanks));
     size_t minRegistrationThreshold = static_cast<size_t>(std::max(0.0, a * nRanks + b)) * input.element_size();
     // Disable window registration if neither NVLink nor MNNVL is supported
+    // TODO replace in NCCL 2.29 with comm query
     if (!mIsNVLINKSupported && !mIsMNNVLSupported)
     {
         minRegistrationThreshold = std::numeric_limits<size_t>::max();
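
The threshold is a linear model in the communicator size: messages smaller than max(0, a * nRanks + b) elements skip NCCL window registration, and raising the threshold to SIZE_MAX disables registration outright when no NVLink/MNNVL fabric is present. A hedged, standalone sketch of that logic (free-function form with illustrative parameter names; the real code lives inside AllreduceOp):

    #include <algorithm>
    #include <cstddef>
    #include <limits>

    // Minimum message size (in bytes) at which window registration pays
    // off; SIZE_MAX means "never register".
    size_t minRegistrationThreshold(double a, double b, int nRanks, size_t elementSize, bool hasNvlinkOrMnnvl)
    {
        size_t threshold = static_cast<size_t>(std::max(0.0, a * nRanks + b)) * elementSize;
        if (!hasNvlinkOrMnnvl)
        {
            // Without NVLink or MNNVL the registered path brings no benefit,
            // so make the threshold unreachable.
            threshold = std::numeric_limits<size_t>::max();
        }
        return threshold;
    }
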
@@ -904,8 +905,8 @@ class AllreduceOp
 
     // 2. Check multicast support
     CUdevice cu_device;
-    auto& cuda_driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
-    TLLM_CU_CHECK(cuda_driver->cuDeviceGet(&cu_device, device_id));
+    TLLM_CU_CHECK(cuDeviceGet(&cu_device, device_id));
+    auto cuda_driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
 
     int multicast_supported = 0;
     TLLM_CU_CHECK(cuda_driver->cuDeviceGetAttribute(
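
The hunk reorders the two calls so cuDeviceGet is invoked directly (the symbol from cuda.h) before the driver wrapper is fetched for the attribute query. For reference, the same multicast check written against the raw CUDA driver API, without TLLM_CU_CHECK or CUDADriverWrapper; a sketch assuming CUDA 12.1+ (where CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED is defined) and that cuInit(0) has already run:

    #include <cuda.h>

    bool deviceSupportsMulticast(int device_id)
    {
        CUdevice cu_device;
        if (cuDeviceGet(&cu_device, device_id) != CUDA_SUCCESS)
        {
            return false;
        }

        // Reports whether the device supports CUDA multicast objects
        // (the cuMulticastCreate family used for NVLink SHARP/MNNVL paths).
        int multicast_supported = 0;
        if (cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cu_device)
            != CUDA_SUCCESS)
        {
            return false;
        }
        return multicast_supported != 0;
    }
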
@@ -1139,18 +1140,16 @@ class AllreduceOp
             },
             [&](c10::intrusive_ptr<c10d::ProcessGroup>& torchPg)
             {
-                // For ProcessGroup, use allgather
-                // Create a sub-group for the ranks in mGroup
-                std::vector<int> group_ranks(mGroup.begin(), mGroup.end());
-                auto group_pg = torchPg->newGroup(group_ranks);
-                if (group_pg)
+                // For ProcessGroup, use allgather directly
+                // Note: This assumes the ProcessGroup is already set up for the correct group
+                std::vector<torch::Tensor> input_tensors
+                    = {torch::tensor({local_mnnvl_status}, torch::kInt32)};
+                std::vector<std::vector<torch::Tensor>> output_tensors(1);
+                output_tensors[0].resize(mGroup.size());
+                auto work = torchPg->allgather(output_tensors, input_tensors);
+                if (work)
                 {
-                    std::vector<torch::Tensor> input_tensors
-                        = {torch::tensor({local_mnnvl_status}, torch::kInt32)};
-                    std::vector<std::vector<torch::Tensor>> output_tensors(1);
-                    output_tensors[0].resize(mGroup.size());
-                    group_pg->allgather(output_tensors, input_tensors)->wait();
-
+                    work->wait();
                     for (size_t i = 0; i < mGroup.size(); ++i)
                     {
                         all_mnnvl_status[i] = output_tensors[0][i].item<int>();
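
Instead of carving a sub-group out of torchPg with newGroup, the new code allgathers directly on the existing ProcessGroup and assumes it already spans exactly the ranks in mGroup. A hedged standalone sketch of that pattern against the c10d API (function and variable names are illustrative); note that c10d expects the receive-side tensors to be preallocated with the right shape and dtype:

    #include <torch/torch.h>
    #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

    std::vector<int> allgatherStatus(c10::intrusive_ptr<c10d::ProcessGroup> pg, int local_status)
    {
        int const world_size = pg->getSize();

        // One int32 tensor per rank; rank i's contribution lands in outputs[0][i].
        std::vector<torch::Tensor> inputs = {torch::tensor({local_status}, torch::kInt32)};
        std::vector<std::vector<torch::Tensor>> outputs(1);
        for (int i = 0; i < world_size; ++i)
        {
            outputs[0].push_back(torch::empty({1}, torch::kInt32));
        }

        auto work = pg->allgather(outputs, inputs);
        if (work)
        {
            work->wait(); // block until the collective completes
        }

        std::vector<int> all_status(world_size);
        for (int i = 0; i < world_size; ++i)
        {
            all_status[i] = outputs[0][i].item<int>();
        }
        return all_status;
    }
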
