Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
}

std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
const std::string& deviceKey) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = devXCCLCommMap_.find(deviceKey);
if (it != devXCCLCommMap_.end()) {
// Reuse the cached communicator if there is one.
return it->second;
}
return nullptr;
}

std::shared_ptr<xcclComm_t> ProcessGroupXCCL::initXCCLComm(
const std::string& deviceKey,
at::Device& device,
OpType opType,
Expand All @@ -505,13 +516,6 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(

usedDeviceIdxs_.insert(device.index());

{
std::lock_guard<std::mutex> lock(mutex_);
if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
return devXCCLCommMap_[deviceKey];
}
}

std::shared_ptr<xcclComm_t> XCCLComm;

bool batchP2P = xcclActiveGroupCounter_ > 0;
Expand Down Expand Up @@ -669,7 +673,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
nanCheck &= enableNanCheck_;
auto device = inputs[0].device();
const auto key = std::to_string(device.index());
auto comm = getXCCLComm(key, device, opType);
std::shared_ptr<xcclComm_t> comm = getXCCLComm(key);
if (comm == nullptr) {
comm = initXCCLComm(key, device, opType);
}

if (!coalescing_state_) {
seqCollective_++;
Expand Down Expand Up @@ -835,7 +842,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
}

op_id_++;
auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
std::shared_ptr<xcclComm_t> comm = getXCCLComm(key);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like no difference with previous code, then why to separate as two APIs/steps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for align nccl, To facilitate future feature integration, we can first align this part.

if (comm == nullptr) {
comm = initXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
}

if (coalescing_state_ & CoalActive) {
if ((coalescing_state_ & CoalP2P) == 0) {
Expand Down
4 changes: 3 additions & 1 deletion src/xccl/ProcessGroupXCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {

c10::intrusive_ptr<Work> endCoalescing(OpType optype);

std::shared_ptr<xcclComm_t> getXCCLComm(
std::shared_ptr<xcclComm_t> getXCCLComm(const std::string& deviceKey);

std::shared_ptr<xcclComm_t> initXCCLComm(
const std::string& deviceKey,
at::Device& device,
OpType opType,
Expand Down
Loading