diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp index e2bf50079..e718a50e5 100644 --- a/src/xccl/ProcessGroupXCCL.cpp +++ b/src/xccl/ProcessGroupXCCL.cpp @@ -502,6 +502,17 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( } std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey) { + std::lock_guard lock(mutex_); + auto it = devXCCLCommMap_.find(deviceKey); + if (it != devXCCLCommMap_.end()) { + // Reuse the cached communicator if there is one. + return it->second; + } + return nullptr; +} + +std::shared_ptr ProcessGroupXCCL::initXCCLComm( const std::string& deviceKey, at::Device& device, OpType opType, @@ -516,13 +527,6 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( usedDeviceIdxs_.insert(device.index()); - { - std::lock_guard lock(mutex_); - if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { - return devXCCLCommMap_[deviceKey]; - } - } - std::shared_ptr XCCLComm; bool batchP2P = xcclActiveGroupCounter_ > 0; @@ -680,7 +684,10 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( nanCheck &= enableNanCheck_; auto device = inputs[0].device(); const auto key = std::to_string(device.index()); - auto comm = getXCCLComm(key, device, opType); + std::shared_ptr comm = getXCCLComm(key); + if (comm == nullptr) { + comm = initXCCLComm(key, device, opType); + } if (!coalescing_state_) { seqCollective_++; @@ -846,7 +853,10 @@ c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( } op_id_++; - auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + std::shared_ptr comm = getXCCLComm(key); + if (comm == nullptr) { + comm = initXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + } if (coalescing_state_ & CoalActive) { if ((coalescing_state_ & CoalP2P) == 0) { diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 632e05cf2..936dc66f8 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -166,7 +166,9 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr endCoalescing(OpType optype); - std::shared_ptr getXCCLComm( + std::shared_ptr getXCCLComm(const std::string& deviceKey); + + std::shared_ptr initXCCLComm( const std::string& deviceKey, at::Device& device, OpType opType,