diff --git a/dipu/tests/python/unittests/test_generator.py b/dipu/tests/python/unittests/test_generator.py index 379e83e39..631b57749 100644 --- a/dipu/tests/python/unittests/test_generator.py +++ b/dipu/tests/python/unittests/test_generator.py @@ -2,7 +2,11 @@ import torch import torch_dipu from torch_dipu import diputype -from torch_dipu.testing._internal.common_utils import TestCase, run_tests +from torch_dipu.testing._internal.common_utils import ( + TestCase, + run_tests, + onlyOn, +) class TestGenerator(TestCase): @@ -20,13 +24,13 @@ def test_python_api(self): torch.cuda.manual_seed(i) state = torch.cuda.get_rng_state(0) - new_state = torch.ones_like(state) + new_state = torch.ones_like(state) * 4 torch.cuda.set_rng_state(new_state, 0) current_state = torch.cuda.get_rng_state(0) self.assertTrue( torch.allclose( current_state, - torch.tensor(1, device=current_state.device, dtype=current_state.dtype), + torch.tensor(4, device=current_state.device, dtype=current_state.dtype), ) ) @@ -194,6 +198,23 @@ def test_default_generators(self): torch.cuda.default_generators[0].manual_seed(1) self.assertEqual(torch.cuda.default_generators[0].initial_seed(), 1) + @onlyOn("CUDA") + def test_cuda_generator(self): + state = torch.cuda.get_rng_state(0) + state[-16] = 4 + state[-15:-8] = 0 + state[-8:] = 0 + torch.cuda.set_rng_state(state) + self.assertEqual(torch.cuda.initial_seed(), 4) + + # invalid offset, offset must be a multiple of 4 + state[-8:] = 1 + try: + torch.cuda.set_rng_state(state) + self.assertTrue(False, "should not go here") + except Exception as ex: + self.assertIn("offset must be a multiple of 4", ex.args[0]) + if __name__ == "__main__": run_tests() diff --git a/dipu/third_party/DIOPI b/dipu/third_party/DIOPI index 65930a539..02f03c6ab 160000 --- a/dipu/third_party/DIOPI +++ b/dipu/third_party/DIOPI @@ -1 +1 @@ -Subproject commit 65930a539938b692a84ba77027e91686b3d2516d +Subproject commit 02f03c6abb20aa39d1d978436a53a2e4ec242d65 diff --git a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp index 4574f2b4f..f6ef6bbf0 100644 --- a/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp +++ b/dipu/torch_dipu/csrc_dipu/binding/ExportRT.cpp @@ -205,7 +205,8 @@ void exportEvent(py::module& m) { py::arg("enable_timing") = false, py::arg("blocking") = false, py::arg("interprocess") = false) .def("record", py::overload_cast<>(&DIPUEvent::record), "record event") - .def("record", py::overload_cast(&DIPUEvent::record), + .def("record", py::overload_cast(&DIPUEvent::record), + py::arg("stream"), py::arg("use_pool") = true, "record event on stream") .def("elapsed_time", &dipu::DIPUEvent::elapsed_time) .def("synchronize", @@ -249,6 +250,8 @@ void exportCommunicator(py::module& m) { return kBackendDefaultTimeout; }); + m.def("dump_info", dumpInfo); + // py::object mdist = py::module::import("torch.distributed"); // py::object register_backend = // mdist.attr("Backend").attr("register_backend"); The first parameter is the diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h index 1af20b840..3165215dd 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUEvent.h @@ -20,6 +20,7 @@ class DIPU_API DIPUEvent { deviceEvent_t event_{nullptr}; c10::DeviceIndex device_index_{-1}; c10::StreamId last_recorded_stream_id_{-1}; + bool use_pool_{true}; public: DIPUEvent(const DIPUEvent&) = delete; @@ -29,7 +30,8 @@ class DIPU_API DIPUEvent { constexpr DIPUEvent(DIPUEvent&& other) noexcept : event_(other.event_), device_index_(other.device_index_), - last_recorded_stream_id_(other.last_recorded_stream_id_) { + last_recorded_stream_id_(other.last_recorded_stream_id_), + use_pool_(other.use_pool_) { other.unsafe_reset(); } @@ -39,6 +41,7 @@ class DIPU_API DIPUEvent { event_ = other.event_; device_index_ = other.device_index_; last_recorded_stream_id_ = other.last_recorded_stream_id_; + use_pool_ = other.use_pool_; other.unsafe_reset(); } return *this; @@ -76,8 +79,9 @@ class DIPU_API DIPUEvent { void record() { record(getCurrentDIPUStream()); } - void record(const DIPUStream& stream) { + void record(const DIPUStream& stream, bool use_pool = true) { if (!initialized()) { + use_pool_ = use_pool; create_event(stream.device_index()); } @@ -124,14 +128,23 @@ class DIPU_API DIPUEvent { void create_event(c10::DeviceIndex device_index) { device_index_ = device_index; DIPUGuard guard(device_index_); - devproxy::createEvent(&event_); + if(use_pool_) { + devproxy::createEvent(&event_); + } else { + devapis::createEvent(&event_); + } } void release_event() { if (initialized()) { DIPUGuard guard(device_index_); - devproxy::destroyEvent(event_); + if(use_pool_) { + devproxy::destroyEvent(event_); + } else { + devapis::destroyEvent(event_); + } event_ = nullptr; + use_pool_ = true; } } }; diff --git a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.cpp b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.cpp index 488fd7c7b..124dfe776 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/core/DIPUGeneratorImpl.cpp @@ -89,7 +89,7 @@ DIPUGeneratorImpl::DIPUGeneratorImpl(at::DeviceIndex device_index) */ void DIPUGeneratorImpl::set_current_seed(uint64_t seed) { seed_ = seed; - offset_ = 0; + set_offset(0); state_need_reset_ = true; } diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp index ea958b12e..8d06f0780 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.cpp @@ -1,6 +1,8 @@ // Copyright (c) 2023, DeepLink. #include "ProcessGroupDICL.h" +#include +#include #include #include @@ -120,6 +122,70 @@ void checkGatherScatterRootRank( } // anonymous namespace +// start WorkStore + +class WorkStore { + struct WorkInfo { + DIPUEvent startEvent_; + DIPUEvent endEvent_; + int rank_; + int comm_size_; + }; + + public: + void setUid(const std::vector& uidVec) { uniqueidVec_ = uidVec; } + + size_t recordStart(const DIPUStream& stream, int rank, int comm_size) { + std::lock_guard lock(mtx_); + info_vec_.push_back(WorkInfo()); + size_t index = info_vec_.size() - 1; + info_vec_[index].startEvent_.record(stream, false); + info_vec_[index].rank_ = rank; + info_vec_[index].comm_size_ = comm_size; + + return index; + } + + void recordEnd(const DIPUStream& stream, size_t index) { + std::lock_guard lock(mtx_); + info_vec_[index].endEvent_.record(stream, false); + } + + void dump(std::string& path) { + for (auto& wi : info_vec_) { + wi.endEvent_.synchronize(); + float duration = wi.startEvent_.elapsed_time(wi.endEvent_); + std::ostringstream oss; + oss << "PG uniqueId = "; + for (int i = 0; i < 32; ++i) { + oss << static_cast(uniqueidVec_[i]); + } + oss << ", comm_size = " << wi.comm_size_ << ", duration = " << duration + << std::endl; + std::string filePath = path + "/rank_" + std::to_string(wi.rank_); + std::ofstream outFile(filePath, std::ios::app); + outFile << oss.str(); + } + + info_vec_.clear(); + } + + private: + std::vector info_vec_; + std::mutex mtx_; + std::vector uniqueidVec_; +}; + +// end WorkStore + +std::vector> global_stores; + +void dumpInfo(std::string& path) { + for (auto p : global_stores) { + p->dump(path); + } +} + // start WorkDICL // currently DICL do not support error check @@ -196,7 +262,10 @@ ProcessGroupDICL::WorkDICL::getFuture() { ProcessGroupDICL::ProcessGroupDICL(const c10::intrusive_ptr& store, int rank, int size) - : c10d::Backend(rank, size), store_(store) { + : c10d::Backend(rank, size), + store_(store), + pWstore_(std::make_shared()) { + global_stores.push_back(pWstore_); char* blockingWait = getenv(DICL_BLOCKING_WAIT); try { if (blockingWait != nullptr) { @@ -238,6 +307,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId, auto vec = std::vector(reinterpret_cast(uniqueId), reinterpret_cast(uniqueId) + devapis::DICL_UNIQUE_ID_BYTES_SIZE); + pWstore_->setUid(vec); store_->set(storeKey, vec); } else { auto vec = store_->get(storeKey); @@ -246,6 +316,7 @@ void ProcessGroupDICL::broadcastUniqueID(commUniqueId* uniqueId, "Unexpected DICL unique ID length received " "from the store"); } + pWstore_->setUid(vec); std::memcpy(uniqueId, vec.data(), vec.size()); } } @@ -442,6 +513,13 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( auto work = c10::make_intrusive( diclComms, blockingWait_, opTimeout_); + size_t eventIndex; + if (opType == OpType::ALLREDUCE) { + eventIndex = + pWstore_->recordStart(diclComms[0]->diclStream_, this->rank_, + inputs[0].element_size() * inputs[0].numel()); + } + OptionalDIPUGuard dipuGuard; pre(diclComms); @@ -466,6 +544,11 @@ c10::intrusive_ptr ProcessGroupDICL::doComm( } post(diclComms); + + if (opType == OpType::ALLREDUCE) { + pWstore_->recordEnd(diclComms[0]->diclStream_, eventIndex); + } + work->record(); work->outputs_ = std::make_shared>(outputs); diff --git a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h index d5ba9da1e..b2f3171ba 100644 --- a/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h +++ b/dipu/torch_dipu/csrc_dipu/runtime/distributed/ProcessGroupDICL.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include #include #include @@ -40,6 +41,10 @@ using c10d::Work; constexpr const char* DICL_BLOCKING_WAIT = "DICL_BLOCKING_WAIT"; constexpr int64_t diclSyncBusyWaitMillis = 30; +void dumpInfo(std::string& path); + +class WorkStore; + /** * ProcessGroupDICL implements DICLbindings for c10d. * @@ -310,6 +315,8 @@ class DIPU_API ProcessGroupDICL : public Backend { bool blockingWait_ = false; std::chrono::milliseconds opTimeout_ = kBackendDefaultTimeout; + + std::shared_ptr pWstore_; }; namespace dicl_hook { diff --git a/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp index 6ef338bc3..03d8d5a75 100644 --- a/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp +++ b/dipu/torch_dipu/csrc_dipu/vendor/cuda/CudaGeneratorImpl.cpp @@ -39,11 +39,12 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl { #else auto new_rng_state = state.data_dtype_initialized(); #endif - memcpy(&input_seed, new_rng_state, seed_size); + memcpy(&input_seed, new_rng_state + states_size, seed_size); this->set_current_seed(input_seed); int64_t philox_offset = 0; if (!no_philox_seed) { - memcpy(&philox_offset, new_rng_state + seed_size, offset_size); + memcpy(&philox_offset, new_rng_state + states_size + seed_size, + offset_size); } this->set_offset(static_cast(philox_offset)); @@ -71,6 +72,11 @@ class CUDAGeneratorImpl : public dipu::DIPUGeneratorImpl { state_need_reset_ = false; } } + + void set_offset(uint64_t offset) override { + TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); + DIPUGeneratorImpl::set_offset(offset); + } }; // NOLINTNEXTLINE(readability-const-return-type) diff --git a/dipu/torch_dipu/dipu/distributed.py b/dipu/torch_dipu/dipu/distributed.py index 73e0cfbac..feb5861e4 100644 --- a/dipu/torch_dipu/dipu/distributed.py +++ b/dipu/torch_dipu/dipu/distributed.py @@ -1,4 +1,5 @@ from datetime import timedelta +import os import torch from torch import distributed as dist @@ -113,6 +114,10 @@ def _wrap_new_group( return _raw_new_group(ranks, timeout, backend, pg_options) +def _wrap_dump_info(path): + _C.dump_info(path) + + def apply_dist_patch(): dist.get_backend = _wrap_get_backend dist.init_process_group = _wrap_init_process_groups @@ -123,3 +128,5 @@ def apply_dist_patch(): if dipu.get_dipu_torch_version() == dipu.torch_ver_200: dist.new_group = _wrap_new_group + + dist.dump_info = _wrap_dump_info