diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index ecef78923..a8cf0909e 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -252,25 +252,39 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
     uint64_t seq,
     bool isP2P,
     const char* profilingTitle,
-    const std::optional<std::vector<at::Tensor>>& inputs)
+    const std::optional<std::vector<at::Tensor>>& inputs,
+    bool enableTiming,
+    bool xpuEventCacheEnabled)
     : Work(rank, opType, profilingTitle, inputs),
       device_(device),
       workStartTime_(std::chrono::steady_clock::now()),
       seq_(seq),
-      isP2P_(isP2P) {
-  xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
+      isP2P_(isP2P),
+      timingEnabled_(enableTiming) {
+  if (xpuEventCacheEnabled) {
+    xcclStartEvent_ = enableTiming
+        ? XPUEventCache::get(device.index())->create(enableTiming)
+        : nullptr;
+    xcclEndEvent_ = XPUEventCache::get(device.index())->create(enableTiming);
+  } else {
+    xcclStartEvent_ =
+        enableTiming ? std::make_shared<at::xpu::XPUEvent>(1) : nullptr;
+    xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enableTiming ? 1 : 0);
+  }
   stashed_for_allocator_safety_ = std::make_shared<TensorShelf>();
 }
 
 ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
     : Work(w.rank_, w.opType_),
       device_(w.device_),
+      xcclStartEvent_(w.xcclStartEvent_),
       xcclEndEvent_(w.xcclEndEvent_),
       blockingWait_(w.blockingWait_),
       workStartTime_(w.workStartTime_),
       seq_(w.seq_),
       isP2P_(w.isP2P_),
-      stashed_for_allocator_safety_(w.stashed_for_allocator_safety_) {}
+      stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
+      timingEnabled_(w.timingEnabled_) {}
 
 ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
 
@@ -369,7 +383,9 @@ ProcessGroupXCCL::ProcessGroupXCCL(
       local_id_(process_group_id++) {
   logPrefix_ = createLogPrefix();
   blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
+  xpuEventCacheEnabled_.store(getCvarBool(TORCH_XCCL_XPU_EVENT_CACHE, true));
   traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);
+  enableTiming_.store(getCvarBool(TORCH_XCCL_ENABLE_TIMING, false));
   this->setGroupUid(options_->group_name);
 
   // In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just
@@ -391,9 +407,11 @@ ProcessGroupXCCL::ProcessGroupXCCL(
   LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
             << "XCCL version: " << XcclVersion
+            << ", TORCH_XCCL_ENABLE_TIMING: " << enableTiming_.load()
             << ", TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
             << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
-            << ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
+            << ", TORCH_XCCL_NAN_CHECK: " << enableNanCheck_
+            << ", TORCH_XCCL_XPU_EVENT_CACHE: " << xpuEventCacheEnabled_;
 
   getGlobalRankStartAndStride(
       options_->global_ranks_in_group,
@@ -453,6 +471,10 @@ uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
   return seqCollective_;
 }
 
+void ProcessGroupXCCL::enableCollectivesTiming() {
+  enableTiming_.store(true);
+}
+
 void ProcessGroupXCCL::setEnableNanCheck(bool enableNanCheck) {
   enableNanCheck_ = enableNanCheck;
 }
@@ -474,7 +496,9 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
       isP2P,
       profilingTitle,
       profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
-                                : std::nullopt);
+                                : std::nullopt,
+      enableTiming_.load(),
+      xpuEventCacheEnabled_.load());
 
   if (record) {
     r->trace_id_ = FlightRecorderXCCL::get()->record(
@@ -486,7 +510,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
         profilingTitle ? profilingTitle : "",
        inputs,
        outputs,
-        nullptr,
+        r->xcclStartEvent_.get(),
        r->xcclEndEvent_.get(),
        options_->timeout,
        pgStatus_,
@@ -495,6 +519,17 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
   return r;
 }
 
+float ProcessGroupXCCL::WorkXCCL::getDuration() const {
+  TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled");
+  TORCH_CHECK(
+      xcclStartEvent_,
+      "getDuration only works if xcclStartEvent_ is populated, true if timing enabled");
+  TORCH_CHECK(
+      xcclEndEvent_,
+      "getDuration only works if xcclEndEvent_ is populated, which should always be true");
+  return xcclStartEvent_->elapsed_time(*xcclEndEvent_);
+}
+
 std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     const std::string& deviceKey,
     at::Device& device,
@@ -643,6 +678,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
 
   work->stashed_for_allocator_safety_->stash(coalescedTensors_);
 
+  if (work->timingEnabled_) {
+    work->xcclStartEvent_->record(stream);
+  }
+
   groupEnd();
 
   work->xcclEndEvent_->record(stream);
@@ -773,6 +812,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
 
   pre(stream, work);
 
+  if (work->timingEnabled_ && !coalescing_state_) {
+    work->xcclStartEvent_->record(stream);
+  }
+
   for (const auto i : c10::irange(inputs.size())) {
     fn(inputs[i], outputs[i], *comm, stream, *cclstream);
   }
@@ -810,12 +853,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   return asyncOp ? work : nullptr;
 }
 
-template <typename Fn>
+template <typename Fn, typename PreProcess, typename PostProcess>
 c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     at::Tensor& tensor,
     Fn fn,
     int peer,
     OpType opType,
+    PreProcess pre,
+    PostProcess post,
     const char* profilingTitle) {
   auto device = tensor.device();
   std::string key;
@@ -865,13 +910,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
   auto cclstream = xcclStreamsMap_.at(key).second;
   syncStream(device, xcclEventsMap_[key], stream);
 
-  if (enableNanCheck_ && opType == OpType::SEND) {
-    checkForNan(tensor, stream);
-  }
-
+  c10::intrusive_ptr<WorkXCCL> work;
   if (!coalescing_state_) {
-    auto work =
-        initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
+    work = initWork(device, rank_, opType, true, profilingTitle, {tensor}, {});
 
     work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
     work->outputs_->push_back(tensor);
@@ -884,37 +925,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
        profilingTitle,
        {tensor},
        {tensor},
-        nullptr,
+        work->xcclStartEvent_.get(),
        work->xcclEndEvent_.get(),
        options_->timeout,
        pgStatus_,
        true);
-    c10::OptionalDeviceGuard gpuGuard(device);
-
-    c10::xpu::XPUCachingAllocator::recordStream(
-        tensor.storage().data_ptr(), stream);
-
-    fn(tensor, *comm, stream, cclstream, p2pTargetRank);
-
-    work->xcclEndEvent_->record(stream);
-    work->blockingWait_ = blockingWait_;
-    std::vector<c10::Stream> streams = {stream.unwrap()};
-    c10::MultiStreamGuard streamGuard(streams);
-    std::vector<at::Device> devices{device};
-    work->future_ = c10::make_intrusive<at::ivalue::Future>(
-        c10::ListType::create(c10::TensorType::get()), devices);
-    work->future_->markCompleted(at::IValue(*work->outputs_));
-    auto id = work->trace_id_;
-    work->future_->addCallback(
-        [id](at::ivalue::Future&) {
-          FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
-        },
-        /*use_future*/ false);
-
-    work->numelIn_ = work->numelOut_ = tensor.numel();
-    setEnqueuedPgStatus(work);
-    return work;
   } else {
     FlightRecorderXCCL::get()->record(
        local_id_,
@@ -930,15 +946,52 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
        options_->timeout,
        pgStatus_,
        true);
-    c10::OptionalDeviceGuard gpuGuard(device);
+  }
+
+  if (enableNanCheck_ && opType == OpType::SEND) {
+    checkForNan(tensor, stream);
+  }
+  if (!coalescing_state_) {
+    if (work->timingEnabled_) {
+      work->xcclStartEvent_->record(stream);
+    }
-
-    c10::xpu::XPUCachingAllocator::recordStream(
-        tensor.storage().data_ptr(), stream);
+    pre(stream, work);
+  }
+  c10::OptionalDeviceGuard gpuGuard(device);
 
-    fn(tensor, *comm, stream, cclstream, p2pTargetRank);
+  c10::xpu::XPUCachingAllocator::recordStream(
+      tensor.storage().data_ptr(), stream);
 
-    return nullptr;
+  ccl::group_start();
+  fn(tensor, *comm, stream, cclstream, p2pTargetRank);
+  ccl::group_end();
+
+  if (!coalescing_state_) {
+    post(stream);
+
+    work->xcclEndEvent_->record(stream);
+    work->blockingWait_ = blockingWait_;
+    work->numelIn_ = work->numelOut_ = tensor.numel();
+    {
+      std::vector<c10::Stream> streams = {stream.unwrap()};
+      c10::MultiStreamGuard streamGuard(streams);
+      std::vector<at::Device> devices{device};
+      work->future_ = c10::make_intrusive<at::ivalue::Future>(
+          c10::ListType::create(c10::TensorType::get()), devices);
+      work->future_->markCompleted(at::IValue(*work->outputs_));
+    }
+
+    auto id = work->trace_id_;
+    work->future_->addCallback(
+        [id](at::ivalue::Future&) {
+          FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
+        },
+        /*use_future*/ false);
+    setEnqueuedPgStatus(work);
   }
+
+  return work;
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
@@ -2043,8 +2096,8 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
   } else if (!usedDeviceIdxs_.empty()) {
     return *usedDeviceIdxs_.begin();
   }
-  int devIdx =
-      static_cast<int>(globalRank() % at::detail::getXPUHooks().getNumGPUs());
+  int devIdx = static_cast<int>(
+      globalRank() % at::detail::getXPUHooks().getNumGPUs());
   LOG(WARNING)
       << logPrefix()
       << c10::str(
diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp
index 632e05cf2..b0ab050c9 100644
--- a/src/xccl/ProcessGroupXCCL.hpp
+++ b/src/xccl/ProcessGroupXCCL.hpp
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include <xccl/XPUEventCache.hpp>
 
 namespace c10d {
 static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {
@@ -35,6 +36,12 @@ static std::vector<std::string> TORCH_XCCL_COORD_CHECK_MILSEC = {
     "TORCH_XCCL_COORD_CHECK_MILSEC",
     "XCCL_COORD_CHECK_MILSEC"};
 
+static std::vector<std::string> TORCH_XCCL_XPU_EVENT_CACHE = {
+    "TORCH_XCCL_XPU_EVENT_CACHE"};
+
+static std::vector<std::string> TORCH_XCCL_ENABLE_TIMING = {
+    "TORCH_XCCL_ENABLE_TIMING"};
+
 using xcclComm_t = ccl::communicator;
 
 static std::vector<std::string> TORCH_XCCL_NAN_CHECK = {"TORCH_XCCL_NAN_CHECK"};
@@ -73,7 +80,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
        uint64_t seq,
        bool isP2P,
        const char* profilingTitle = nullptr,
-        const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt);
+        const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt,
+        bool enableTiming = false,
+        bool xpuEventCacheEnabled = false);
 
     WorkXCCL(const WorkXCCL& w);
     ~WorkXCCL() override;
@@ -87,6 +96,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
 
     void synchronizeStream();
 
+    float getDuration() const override;
+
     bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
 
     c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
@@ -103,6 +114,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
    protected:
     at::Device device_;
+    std::shared_ptr<at::xpu::XPUEvent> xcclStartEvent_;
     std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
     bool isBarrierOp_{false};
     bool blockingWait_{false};
@@ -117,6 +129,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     std::shared_ptr<std::vector<at::Tensor>> outputs_;
     std::shared_ptr<TensorShelf> stashed_for_allocator_safety_;
     c10::intrusive_ptr<at::ivalue::Future> future_;
+    bool timingEnabled_;
 
     friend class ProcessGroupXCCL;
   };
@@ -312,7 +325,27 @@ class TORCH_API ProcessGroupXCCL : public Backend {
      Fn fn,
      int peer,
      OpType opType,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle) {
+    return pointToPoint(
+        tensor,
+        fn,
+        peer,
+        opType,
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
+        [](at::xpu::XPUStream&) {},
+        profilingTitle);
+  }
+
+  template <typename Fn, typename PreProcess, typename PostProcess>
+  c10::intrusive_ptr<Work> pointToPoint(
+      at::Tensor& tensor,
+      Fn fn,
+      int peer,
+      OpType opType,
+      PreProcess pre,
+      PostProcess post,
+      const char* profilingTitle);
 
   c10::intrusive_ptr<Work> allreduce_impl(
       at::Tensor& tensor,
@@ -419,6 +452,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
 
   uint64_t getSequenceNumberForGroup() override;
 
+  void enableCollectivesTiming() override;
+
   std::string createLogPrefix() const;
 
   const std::string& logPrefix() const;
@@ -440,6 +475,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   c10::intrusive_ptr<Store> store_;
   uint64_t xcclCommCounter_{0};
   std::mutex mutex_;
+  std::atomic<bool> xpuEventCacheEnabled_;
   std::set<int> usedDeviceIdxs_;
   int coalescing_state_ = 0;
   at::Device coalescedDevice_ = at::Device("xpu");
@@ -447,6 +483,7 @@
   bool coalescedAsync_;
   TensorShelf coalescedTensors_;
   bool blockingWait_ = false;
+  std::atomic<bool> enableTiming_;
   static thread_local uint64_t xcclActiveGroupCounter_;
   uint64_t seqCollective_{0};
   uint64_t seqP2P_{0};
diff --git a/src/xccl/XPUEventCache.cpp b/src/xccl/XPUEventCache.cpp
new file mode 100644
index 000000000..54682a449
--- /dev/null
+++ b/src/xccl/XPUEventCache.cpp
@@ -0,0 +1,43 @@
+#include <map>
+#include <mutex>
+#include <xccl/XPUEventCache.hpp>
+
+namespace c10d {
+
+XPUEventCache::XPUEventCache() = default;
+
+std::shared_ptr<at::xpu::XPUEvent> XPUEventCache::create(bool timing) {
+  auto deleter = [cache = shared_from_this(),
+                  timing](at::xpu::XPUEvent* event) {
+    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
+
+    cache->eventsArray_[timing ? 1 : 0].push_back(event);
+  };
+  at::xpu::XPUEvent* event = nullptr;
+  {
+    std::lock_guard<std::mutex> lock(cacheMutex_);
+    auto& events = eventsArray_[timing ? 1 : 0];
+    // If we still have events in the cache, reuse one. Otherwise, create
+    // a new one.
+    if (!events.empty()) {
+      event = events.front();
+      events.pop_front();
+    } else {
+      event = new at::xpu::XPUEvent(timing ? 1 : 0);
+    }
+  }
+  return std::shared_ptr<at::xpu::XPUEvent>(event, std::move(deleter));
+}
+
+std::shared_ptr<XPUEventCache> XPUEventCache::get(at::DeviceIndex device) {
+  static thread_local std::map<at::DeviceIndex, std::shared_ptr<XPUEventCache>>
+      cacheDeviceMap;
+  // Check if the device is already in the map; if not, add a new entry.
+  auto it = cacheDeviceMap.find(device);
+  if (it == cacheDeviceMap.end()) {
+    cacheDeviceMap.emplace(device, std::make_shared<XPUEventCache>());
+  }
+  return cacheDeviceMap[device];
+}
+
+} // namespace c10d
diff --git a/src/xccl/XPUEventCache.hpp b/src/xccl/XPUEventCache.hpp
new file mode 100644
index 000000000..864e93eda
--- /dev/null
+++ b/src/xccl/XPUEventCache.hpp
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <array>
+#include <deque>
+#include <memory>
+#include <mutex>
+
+#include <ATen/xpu/XPUEvent.h>
+#include <c10/core/Device.h>
+
+namespace c10d {
+
+class TORCH_API XPUEventCache
+    : public std::enable_shared_from_this<XPUEventCache> {
+ public:
+  XPUEventCache();
+  std::shared_ptr<at::xpu::XPUEvent> create(bool timing);
+  static std::shared_ptr<XPUEventCache> get(at::DeviceIndex device);
+
+ private:
+  std::mutex cacheMutex_;
+  // NOTE: We intentionally store raw pointers so that
+  // we do not attempt to destroy the event objects on process exit,
+  // because the XPU runtime may already be gone.
+  std::array<std::deque<at::xpu::XPUEvent*>, 2>
+      eventsArray_; // 0 for timing=false, 1 for timing=true
+};
+
+} // namespace c10d
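
Reviewer note on the timing path (commentary, not part of the patch): with timing on, via TORCH_XCCL_ENABLE_TIMING=1 or a later enableCollectivesTiming() call, each WorkXCCL records xcclStartEvent_ on the stream just before the collective is enqueued and xcclEndEvent_ just after, and getDuration() returns xcclStartEvent_->elapsed_time(*xcclEndEvent_). The sketch below mirrors that flow with a std::chrono-based stand-in event so it compiles without the XCCL stack; StubEvent and the 5 ms sleep are invented for illustration.

#include <chrono>
#include <iostream>
#include <optional>
#include <thread>

// Stand-in for at::xpu::XPUEvent: record() stamps "now", and
// elapsed_time() reports milliseconds between two recorded events.
struct StubEvent {
  std::chrono::steady_clock::time_point t;
  void record() { t = std::chrono::steady_clock::now(); }
  float elapsed_time(const StubEvent& end) const {
    return std::chrono::duration<float, std::milli>(end.t - t).count();
  }
};

int main() {
  bool timingEnabled = true; // TORCH_XCCL_ENABLE_TIMING=1 in the real code

  // The start event exists only when timing is enabled (xcclStartEvent_);
  // the end event always exists (xcclEndEvent_).
  std::optional<StubEvent> start;
  StubEvent end;

  if (timingEnabled) {
    start.emplace();
    start->record(); // work->xcclStartEvent_->record(stream);
  }
  std::this_thread::sleep_for(std::chrono::milliseconds(5)); // the collective
  end.record(); // work->xcclEndEvent_->record(stream);

  if (timingEnabled && start) { // the TORCH_CHECKs guarding getDuration()
    std::cout << "duration: " << start->elapsed_time(end) << " ms\n";
  }
}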
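
Likewise, the recycling idiom in XPUEventCache::create may be easier to review in isolation. The self-contained analogue below (a hypothetical EventCache with a stand-in Event type, not the class added by this patch) shows the core trick: the shared_ptr deleter captures a shared_ptr to the cache and pushes the raw event back into a pool instead of deleting it, so a later create() hands the same object out again. This is also why the cache holds raw pointers: recycled events are never destroyed at static-destruction time, when the XPU runtime may already be gone.

#include <array>
#include <deque>
#include <memory>
#include <mutex>

struct Event {}; // stand-in for at::xpu::XPUEvent

class EventCache : public std::enable_shared_from_this<EventCache> {
 public:
  std::shared_ptr<Event> create(bool timing) {
    // The deleter keeps the cache alive (shared_from_this) until the last
    // outstanding event is returned, then recycles rather than deletes.
    auto deleter = [cache = shared_from_this(), timing](Event* e) {
      std::lock_guard<std::mutex> lock(cache->mutex_);
      cache->pools_[timing ? 1 : 0].push_back(e);
    };
    Event* e = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      auto& pool = pools_[timing ? 1 : 0];
      if (!pool.empty()) {
        e = pool.front(); // reuse a previously returned event
        pool.pop_front();
      } else {
        e = new Event(); // pool empty: allocate a fresh one
      }
    }
    return std::shared_ptr<Event>(e, std::move(deleter));
  }

 private:
  std::mutex mutex_;
  // Index 0 pools timing-disabled events, index 1 timing-enabled ones.
  std::array<std::deque<Event*>, 2> pools_;
};

int main() {
  auto cache = std::make_shared<EventCache>();
  Event* first = nullptr;
  {
    auto ev = cache->create(/*timing=*/true);
    first = ev.get();
  } // ev's deleter runs here and returns the event to the pool
  auto reused = cache->create(/*timing=*/true);
  return reused.get() == first ? 0 : 1; // exits 0: the same object was reused
}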