147 changes: 100 additions & 47 deletions src/xccl/ProcessGroupXCCL.cpp
@@ -252,25 +252,39 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
uint64_t seq,
bool isP2P,
const char* profilingTitle,
const std::optional<std::vector<at::Tensor>>& inputs)
const std::optional<std::vector<at::Tensor>>& inputs,
bool enableTiming,
bool xpuEventCacheEnabled)
: Work(rank, opType, profilingTitle, inputs),
device_(device),
workStartTime_(std::chrono::steady_clock::now()),
seq_(seq),
isP2P_(isP2P) {
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
isP2P_(isP2P),
timingEnabled_(enableTiming) {
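  // The end event is always created to mark completion; a start event exists
  // only when timing is requested, so getDuration() has both endpoints.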
if (xpuEventCacheEnabled) {
xcclStartEvent_ = enableTiming
? XPUEventCache::get(device.index())->create(enableTiming)
: nullptr;
xcclEndEvent_ = XPUEventCache::get(device.index())->create(enableTiming);
} else {
xcclStartEvent_ =
enableTiming ? std::make_shared<at::xpu::XPUEvent>(1) : nullptr;
xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enableTiming ? 1 : 0);
}
stashed_for_allocator_safety_ = std::make_shared<TensorShelf>();
}

ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
: Work(w.rank_, w.opType_),
device_(w.device_),
xcclStartEvent_(w.xcclStartEvent_),
xcclEndEvent_(w.xcclEndEvent_),
blockingWait_(w.blockingWait_),
workStartTime_(w.workStartTime_),
seq_(w.seq_),
isP2P_(w.isP2P_),
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_) {}
stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
timingEnabled_(w.timingEnabled_) {}

ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;

@@ -369,7 +383,9 @@ ProcessGroupXCCL::ProcessGroupXCCL(
local_id_(process_group_id++) {
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
xpuEventCacheEnabled_.store(getCvarBool(TORCH_XCCL_XPU_EVENT_CACHE, true));
traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);
enableTiming_.store(getCvarBool(TORCH_XCCL_ENABLE_TIMING, false));

this->setGroupUid(options_->group_name);
// In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just
@@ -391,9 +407,11 @@ ProcessGroupXCCL::ProcessGroupXCCL(

LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
<< "XCCL version: " << XcclVersion
<< ", TORCH_XCCL_ENABLE_TIMING: " << enableTiming_.load()
<< ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
<< ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
<< ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
<< ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_
<< ", TORCH_XCCL_XPU_EVENT_CACHE: " << xpuEventCacheEnabled_;

getGlobalRankStartAndStride(
options_->global_ranks_in_group,
@@ -453,6 +471,10 @@ uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
return seqCollective_;
}

void ProcessGroupXCCL::enableCollectivesTiming() {
enableTiming_.store(true);
}

void ProcessGroupXCCL::setEnableNanCheck(bool enableNanCheck) {
enableNanCheck_ = enableNanCheck;
}
@@ -474,7 +496,9 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
isP2P,
profilingTitle,
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
: std::nullopt);
: std::nullopt,
enableTiming_.load(),
xpuEventCacheEnabled_.load());

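  // Hand the flight recorder both events; xcclStartEvent_ is nullptr when
  // timing is disabled.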
if (record) {
r->trace_id_ = FlightRecorderXCCL::get()->record(
@@ -486,7 +510,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
profilingTitle ? profilingTitle : "",
inputs,
outputs,
nullptr,
r->xcclStartEvent_.get(),
r->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
@@ -495,6 +519,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
return r;
}

float ProcessGroupXCCL::WorkXCCL::getDuration() const {
TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
TORCH_CHECK(
xcclStartEvent_,
"getDuration only works if xcclStartEvents_ is populated, true if timing enabled");
TORCH_CHECK(
xcclEndEvent_,
"getDuration only works if xcclEndEvents_ is populated, which should always be true");
return xcclStartEvent_->elapsed_time(*xcclEndEvent_);
}

std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
const std::string& deviceKey,
at::Device& device,
@@ -643,6 +678,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {

work->stashed_for_allocator_safety_->stash(coalescedTensors_);

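  // collective() skips the start event while coalescing_state_ is set, so it
  // is recorded here to bracket the entire coalesced group.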
if (work->timingEnabled_) {
work->xcclStartEvent_->record(stream);
}

groupEnd();

work->xcclEndEvent_->record(stream);
@@ -773,6 +812,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(

pre(stream, work);

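  // Bracket the enqueued kernels with the start event; coalesced ops defer
  // this to endCoalescing().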
if (work->timingEnabled_ && !coalescing_state_) {
work->xcclStartEvent_->record(stream);
}

for (const auto i : c10::irange(inputs.size())) {
fn(inputs[i], outputs[i], *comm, stream, *cclstream);
}
@@ -810,12 +853,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
return asyncOp ? work : nullptr;
}

template <typename Fn>
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
at::Tensor& tensor,
Fn fn,
int peer,
OpType opType,
PreProcess pre,
PostProcess post,
const char* profilingTitle) {
auto device = tensor.device();
std::string key;
@@ -865,13 +910,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
auto cclstream = xcclStreamsMap_.at(key).second;
syncStream(device, xcclEventsMap_[key], stream);

if (enableNanCheck_ && opType == OpType::SEND) {
checkForNan(tensor, stream);
}

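  // Declared outside the branch so the non-coalesced path below can record
  // events on it and build its future.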
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
if (!coalescing_state_) {
auto work =
initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
work = initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
work->outputs_->push_back(tensor);

@@ -884,37 +925,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
profilingTitle,
{tensor},
{tensor},
nullptr,
work->xcclStartEvent_.get(),
work->xcclEndEvent_.get(),
options_->timeout,
pgStatus_,
true);

c10::OptionalDeviceGuard gpuGuard(device);

c10::xpu::XPUCachingAllocator::recordStream(
tensor.storage().data_ptr(), stream);

fn(tensor, *comm, stream, cclstream, p2pTargetRank);

work->xcclEndEvent_->record(stream);
work->blockingWait_ = blockingWait_;
std::vector<c10::Stream> streams = {stream.unwrap()};
c10::MultiStreamGuard streamGuard(streams);
std::vector<at::Device> devices{device};
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
auto id = work->trace_id_;
work->future_->addCallback(
[id](at::ivalue::Future&) {
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
},
/*use_future*/ false);

work->numelIn_ = work->numelOut_ = tensor.numel();
setEnqueuedPgStatus(work);
return work;
} else {
FlightRecorderXCCL::get()->record(
local_id_,
@@ -930,15 +946,52 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
options_->timeout,
pgStatus_,
true);
c10::OptionalDeviceGuard gpuGuard(device);
}

if (enableNanCheck_ && opType == OpType::SEND) {
checkForNan(tensor, stream);
}
if (!coalescing_state_) {
if (work->timingEnabled_) {
work->xcclStartEvent_->record(stream);
}

c10::xpu::XPUCachingAllocator::recordStream(
tensor.storage().data_ptr(), stream);
pre(stream, work);
}
c10::OptionalDeviceGuard gpuGuard(device);

fn(tensor, *comm, stream, cclstream, p2pTargetRank);
c10::xpu::XPUCachingAllocator::recordStream(
tensor.storage().data_ptr(), stream);

return nullptr;
ccl::group_start();
fn(tensor, *comm, stream, cclstream, p2pTargetRank);
ccl::group_end();

if (!coalescing_state_) {
post(stream);

work->xcclEndEvent_->record(stream);
work->blockingWait_ = blockingWait_;
work->numelIn_ = work->numelOut_ = tensor.numel();
{
std::vector<c10::Stream> streams = {stream.unwrap()};
c10::MultiStreamGuard streamGuard(streams);
std::vector<at::Device> devices{device};
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
}

auto id = work->trace_id_;
work->future_->addCallback(
[id](at::ivalue::Future&) {
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
},
/*use_future*/ false);
setEnqueuedPgStatus(work);
}

return work;
}

c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
@@ -2043,8 +2096,8 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
} else if (!usedDeviceIdxs_.empty()) {
return *usedDeviceIdxs_.begin();
}
int devIdx =
static_cast<int16_t>(globalRank() % at::detail::getXPUHooks().getNumGPUs());
int devIdx = static_cast<int16_t>(
globalRank() % at::detail::getXPUHooks().getNumGPUs());
LOG(WARNING)
<< logPrefix()
<< c10::str(
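For orientation (this illustration is not part of the diff): when TORCH_XCCL_ENABLE_TIMING is set, the work object above gets a start event that is recorded on the collective's stream before the kernels are enqueued, the end event is recorded afterwards, and WorkXCCL::getDuration() returns the elapsed time between the two. Below is a minimal standalone sketch of that event-bracketing pattern; the header paths and the c10::xpu::getCurrentXPUStream() helper are assumptions about the surrounding ATen/c10 XPU API rather than anything taken from this PR.

// Illustrative sketch only, not part of the PR. The bracketing pattern
// mirrors what WorkXCCL does when timing is enabled.
#include <ATen/xpu/XPUEvent.h>
#include <c10/xpu/XPUStream.h>
#include <functional>

float timedRegion(const std::function<void(c10::xpu::XPUStream&)>& body) {
  auto stream = c10::xpu::getCurrentXPUStream();
  at::xpu::XPUEvent start(/*enable_timing=*/true);  // like xcclStartEvent_
  at::xpu::XPUEvent end(/*enable_timing=*/true);    // like xcclEndEvent_

  start.record(stream);  // analogous to work->xcclStartEvent_->record(stream)
  body(stream);          // analogous to fn(input, output, comm, stream, ...)
  end.record(stream);    // analogous to work->xcclEndEvent_->record(stream)

  end.synchronize();               // the end event must have completed
  return start.elapsed_time(end);  // milliseconds, same call as getDuration()
}

In the PR itself the same pair of record() calls brackets either a single collective (in collective()) or a whole coalesced group (in endCoalescing()), and elapsed_time() is exactly what getDuration() invokes.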
41 changes: 39 additions & 2 deletions src/xccl/ProcessGroupXCCL.hpp
@@ -22,6 +22,7 @@
#include <torch/csrc/distributed/c10d/TraceUtils.h>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <xccl/ProcessGroupXCCLMonitor.hpp>
#include <xccl/XPUEventCache.hpp>
namespace c10d {

static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {
@@ -35,6 +36,12 @@ static std::vector<std::string> TORCH_XCCL_COORD_CHECK_MILSEC = {
"TORCH_XCCL_COORD_CHECK_MILSEC",
"XCCL_COORD_CHECK_MILSEC"};

static std::vector<std::string> TORCH_XCCL_XPU_EVENT_CACHE = {
"TORCH_XCCL_XPU_EVENT_CACHE"};

static std::vector<std::string> TORCH_XCCL_ENABLE_TIMING = {
"TORCH_XCCL_ENABLE_TIMING"};

using xcclComm_t = ccl::communicator;

static std::vector<std::string> TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"};
@@ -73,7 +80,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
uint64_t seq,
bool isP2P,
const char* profilingTitle = nullptr,
const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt);
const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt,
bool enableTiming = false,
bool xpuEventCacheEnabled = false);
WorkXCCL(const WorkXCCL& w);
~WorkXCCL() override;

Expand All @@ -87,6 +96,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {

void synchronizeStream();

float getDuration() const override;

bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
@@ -103,6 +114,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {

protected:
at::Device device_;
std::shared_ptr<at::xpu::XPUEvent> xcclStartEvent_;
std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
bool isBarrierOp_{false};
bool blockingWait_{false};
@@ -117,6 +129,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
std::shared_ptr<std::vector<at::Tensor>> outputs_;
std::shared_ptr<TensorShelf> stashed_for_allocator_safety_;
c10::intrusive_ptr<at::ivalue::Future> future_;
bool timingEnabled_;
friend class ProcessGroupXCCL;
};

@@ -312,7 +325,27 @@ class TORCH_API ProcessGroupXCCL : public Backend {
Fn fn,
int peer,
OpType opType,
const char* profilingTitle = nullptr);
const char* profilingTitle) {
return pointToPoint(
tensor,
fn,
peer,
opType,
[](at::xpu::XPUStream&,
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
[](at::xpu::XPUStream&) {},
profilingTitle);
}

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> pointToPoint(
at::Tensor& tensor,
Fn fn,
int peer,
OpType opType,
PreProcess pre,
PostProcess post,
const char* profilingTitle);

c10::intrusive_ptr<Work> allreduce_impl(
at::Tensor& tensor,
@@ -419,6 +452,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {

uint64_t getSequenceNumberForGroup() override;

void enableCollectivesTiming() override;

std::string createLogPrefix() const;

const std::string& logPrefix() const;
@@ -440,13 +475,15 @@ class TORCH_API ProcessGroupXCCL : public Backend {
c10::intrusive_ptr<Store> store_;
uint64_t xcclCommCounter_{0};
std::mutex mutex_;
std::atomic<bool> xpuEventCacheEnabled_;
std::set<int> usedDeviceIdxs_;
int coalescing_state_ = 0;
at::Device coalescedDevice_ = at::Device("xpu");
std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr;
bool coalescedAsync_;
TensorShelf coalescedTensors_;
bool blockingWait_ = false;
std::atomic<bool> enableTiming_;
static thread_local uint64_t xcclActiveGroupCounter_;
uint64_t seqCollective_{0};
uint64_t seqP2P_{0};
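The new xpuEventCacheEnabled path reuses events through XPUEventCache::get(device.index())->create(enableTiming); the cache class itself lives in xccl/XPUEventCache.hpp and is not shown in this diff. Below is a generic sketch of the per-device event-pooling pattern such a cache typically implements; everything beyond the two calls visible in the diff (class layout, member names, recycling strategy) is an assumption made for illustration.

// Hypothetical sketch of a per-device XPU event pool; not the real
// XPUEventCache from xccl/XPUEventCache.hpp, whose implementation is not part
// of this diff. Only get(device) and create(enable_timing) match the calls
// made in WorkXCCL's constructor.
#include <ATen/xpu/XPUEvent.h>
#include <array>
#include <deque>
#include <memory>
#include <mutex>

class EventCacheSketch {
 public:
  // One cache instance per device index, created lazily on first use.
  static EventCacheSketch* get(int device_index) {
    static std::array<EventCacheSketch, 64> caches;
    return &caches.at(static_cast<size_t>(device_index));
  }

  // Returns a shared_ptr whose deleter puts the event back into the pool
  // instead of destroying it, so frequently issued collectives reuse events.
  std::shared_ptr<at::xpu::XPUEvent> create(bool enable_timing) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto& pool = pools_[enable_timing ? 1 : 0];
    at::xpu::XPUEvent* event = nullptr;
    if (!pool.empty()) {
      event = pool.back();
      pool.pop_back();
    } else {
      event = new at::xpu::XPUEvent(enable_timing);
    }
    return std::shared_ptr<at::xpu::XPUEvent>(
        event, [this, enable_timing](at::xpu::XPUEvent* e) {
          std::lock_guard<std::mutex> lock(mutex_);
          pools_[enable_timing ? 1 : 0].push_back(e);  // recycle, don't free
        });
  }

 private:
  std::mutex mutex_;
  // Separate pools for timing-enabled and timing-disabled events.
  std::array<std::deque<at::xpu::XPUEvent*>, 2> pools_;
};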