From 08de73408e61b0202c6f2016c727982a18dea189 Mon Sep 17 00:00:00 2001 From: Han Chao Date: Mon, 22 Sep 2025 13:23:28 +0800 Subject: [PATCH 1/2] separate comm init from getXCClComm --- src/xccl/ProcessGroupXCCL.cpp | 27 ++++++++++++++++++--------- src/xccl/ProcessGroupXCCL.hpp | 4 +++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 923771a13f..3fd22d62f2 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -494,6 +494,16 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( } std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey) { + std::lock_guard lock(mutex_); + if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + return devXCCLCommMap_[deviceKey]; + } + return nullptr; +} + +std::shared_ptr ProcessGroupXCCL::initXCCLComm( const std::string& deviceKey, at::Device& device, OpType opType, @@ -508,13 +518,6 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( usedDeviceIdxs_.insert(device.index()); - { - std::lock_guard lock(mutex_); - if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { - return devXCCLCommMap_[deviceKey]; - } - } - std::shared_ptr XCCLComm; bool batchP2P = xcclActiveGroupCounter_ > 0; @@ -673,7 +676,10 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( seqCollective_++; auto device = inputs[0].device(); const auto key = std::to_string(device.index()); - auto comm = getXCCLComm(key, device, opType); + std::shared_ptr comm = getXCCLComm(key); + if (comm == nullptr) { + comm = initXCCLComm(key, device, opType); + } if (coalescing_state_ & CoalActive) { if ((coalescing_state_ & CoalColl) == 0) { @@ -824,7 +830,10 @@ c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( } op_id_++; - auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + std::shared_ptr comm = getXCCLComm(key); + if (comm == nullptr) { + comm = initXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + } if (coalescing_state_ & CoalActive) { if ((coalescing_state_ & CoalP2P) == 0) { diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 2516c0b912..eef1a5ade9 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -166,7 +166,9 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr endCoalescing(OpType optype); - std::shared_ptr getXCCLComm( + std::shared_ptr getXCCLComm(const std::string& deviceKey); + + std::shared_ptr initXCCLComm( const std::string& deviceKey, at::Device& device, OpType opType, From 284e3db5bc0ad5062fcd59d648799333a3ccac04 Mon Sep 17 00:00:00 2001 From: Chao Han Date: Thu, 9 Oct 2025 09:19:36 +0800 Subject: [PATCH 2/2] Update src/xccl/ProcessGroupXCCL.cpp Co-authored-by: Dmitry Rogozhkin --- src/xccl/ProcessGroupXCCL.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index 3fd22d62f2..12bec63aba 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -496,9 +496,10 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey) { std::lock_guard lock(mutex_); - if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + auto it = devXCCLCommMap_.find(deviceKey); + if (it != devXCCLCommMap_.end()) { // Reuse the cached communicator if there is one. - return devXCCLCommMap_[deviceKey]; + return it->second; } return nullptr; }