diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index 0c938dabe2..ac5fb93992 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -20,7 +20,6 @@ #endif // defined(BTAS_IS_USABLE) #endif // defined(BLOCK_SPARSE_GEMM) -#include #include #if !defined(BLOCK_SPARSE_GEMM) #include @@ -30,14 +29,13 @@ #ifdef BSPMM_HAS_LIBINT #include -#include #endif // TA is only usable if MADNESS backend is used #if defined(BSPMM_HAS_TILEDARRAY) && defined(TTG_USE_MADNESS) -# define BSPMM_BUILD_TA_TEST -# include -# include +#define BSPMM_BUILD_TA_TEST +#include +#include #endif #include "ttg.h" @@ -54,12 +52,12 @@ using namespace ttg; // shallow-copy storage using storage_type = btas::mohndle, btas::Handle::shared_ptr>; // deep-copy storage -//using storage_type = btas::varray; -# ifndef BSPMM_BUILD_TA_TEST // TA overloads btas's impl of btas::dot with its own, but must use TA::Range +// using storage_type = btas::varray; +#ifndef BSPMM_BUILD_TA_TEST // TA overloads btas's impl of btas::dot with its own, but must use TA::Range using blk_t = btas::Tensor; -# else +#else using blk_t = btas::Tensor; -# endif +#endif #if defined(TTG_USE_PARSEC) namespace ttg { @@ -296,8 +294,6 @@ class SpMM { long nbphases() const { return plan_->nb_steps(); } - std::pair gemmsperrankperphase() const { return plan_->gemmsperrankperphase(); } - /// Plan: group all GEMMs in blocks of efficient size class Plan { public: @@ -315,21 +311,20 @@ class SpMM { const long comm_threshold_; private: - struct long_tuple_hash : public std::unary_function, std::size_t> { - std::size_t operator()(const std::tuple &k) const { - return static_cast(std::get<0>(k)) | (static_cast(std::get<1>(k)) << 21) | - (static_cast(std::get<2>(k)) << 21); + struct long_pair_hash : public std::unary_function, std::size_t> { + std::size_t operator()(const std::tuple &k) const { + return static_cast(std::get<0>(k)) | (static_cast(std::get<1>(k)) << 32); } }; - using gemmset_t = std::set>; - using step_vector_t = std::vector>; - using step_per_tile_t = std::unordered_map, std::set, long_tuple_hash>; - using bcastset_t = std::set>; + using bcastset_t = std::unordered_set, long_pair_hash>; using comm_plan_t = std::vector>; - using step_t = std::tuple; + using full_comm_plan_t = std::tuple; - const step_t steps_; + const long dim_; + const long mt_, nt_, kt_; + const long mns_, nns_, kns_; + const full_comm_plan_t comm_plan_; public: Plan(const std::vector> &a_rowidx_to_colidx, @@ -349,17 +344,24 @@ class SpMM { , P_(P) , Q_(Q) , lookahead_(lookahead + 1) // users generally understand that a lookahead of 0 still progresses - , steps_(strategy_selector(memory, forced_split)) - , comm_threshold_(3) { - if (tracing()) display_plan(); + , dim_(strategy_selector(memory, forced_split)) + , comm_plan_(regular_cube_strategy(dim_)) + , comm_threshold_(3) + , mt_(mTiles_.size()) + , nt_(nTiles_.size()) + , kt_(kTiles_.size()) + , mns_((mt_ + dim_ - 1) / dim_) + , nns_((nt_ + dim_ - 1) / dim_) + , kns_((kt_ + dim_ - 1) / dim_) { + if (tracing()) display_comm_plan(); } - step_t strategy_selector(size_t memory, long forced_split) const { + long strategy_selector(size_t memory, long forced_split) const { if (0 == forced_split) return active_set_strategy(memory); - return regular_cube_strategy(forced_split); + return forced_split; } - step_t active_set_strategy(size_t memory) const { + long active_set_strategy(size_t memory) const { ActiveSetStrategy st(a_rowidx_to_colidx_, a_colidx_to_rowidx_, b_rowidx_to_colidx_, b_colidx_to_rowidx_, mTiles_, nTiles_, kTiles_, memory); @@ -466,13 +468,10 @@ class SpMM { (double)excess / (double)(3 * cube_dim * cube_dim)); } - return regular_cube_strategy(cube_dim); + return cube_dim; } - step_t regular_cube_strategy(long cube_dim) const { - step_vector_t steps; - step_per_tile_t steps_per_tile_A; - step_per_tile_t steps_per_tile_B; + full_comm_plan_t regular_cube_strategy(long cube_dim) const { comm_plan_t comm_plan_A; comm_plan_t comm_plan_B; auto rank = ttg_default_execution_context().rank(); @@ -492,6 +491,10 @@ class SpMM { ttg::print("On rank ", ttg_default_execution_context().rank(), " Planning with a cube_dim of ", cube_dim, " over a problem of ", mt, "x", nt, "x", kt, " gives a plan of ", mns, "x", nns, "x", kns); + long nnz_in_AB = 0; + for (auto mm = 0l; mm < a_rowidx_to_colidx_.size(); mm++) nnz_in_AB += a_rowidx_to_colidx_[mm].size(); + for (auto mm = 0l; mm < b_rowidx_to_colidx_.size(); mm++) nnz_in_AB += b_rowidx_to_colidx_[mm].size(); + std::vector a_sent; std::vector b_sent; std::vector a_in_comm_step; @@ -502,72 +505,44 @@ class SpMM { b_in_comm_step.resize(P_ * Q_); comm_plan_A.resize(P_ * Q_); comm_plan_B.resize(P_ * Q_); - for (long mm = 0; mm < mns; mm++) { - for (long nn = 0; nn < nns; nn++) { - for (long kk = 0; kk < kns; kk++) { - gemmset_t gemms; - gemmset_t local_gemms; - long nb_local_gemms = 0; - for (long m = mm * cube_dim; m < (mm + 1) * cube_dim && m < mt; m++) { - if (m >= a_rowidx_to_colidx_.size() || a_rowidx_to_colidx_[m].empty()) continue; - for (long k = kk * cube_dim; k < (kk + 1) * cube_dim && k < kt; k++) { - if (k >= b_rowidx_to_colidx_.size() || b_rowidx_to_colidx_[k].empty()) continue; - if (std::find(a_rowidx_to_colidx_[m].begin(), a_rowidx_to_colidx_[m].end(), k) == - a_rowidx_to_colidx_[m].end()) - continue; - for (long n = nn * cube_dim; n < (nn + 1) * cube_dim && n < nt; n++) { - if (n >= b_colidx_to_rowidx_.size() || b_colidx_to_rowidx_[n].empty()) continue; - if (std::find(b_colidx_to_rowidx_[n].begin(), b_colidx_to_rowidx_[n].end(), k) == - b_colidx_to_rowidx_[n].end()) - continue; - auto r = keymap_(Key<2>({m, n})); - if (r == rank) { - local_gemms.insert({m, n, k}); - nb_local_gemms++; - } - gemms.insert({m, n, k}); - auto it = steps_per_tile_A.find(std::make_tuple(r, m, k)); - if (it == steps_per_tile_A.end()) { - std::set f; - f.insert(steps.size()); - steps_per_tile_A.insert({std::make_tuple(r, m, k), f}); - } else { - it->second.insert(steps.size()); - } - - it = steps_per_tile_B.find(std::make_tuple(r, k, n)); - if (it == steps_per_tile_B.end()) { - std::set f; - f.insert(steps.size()); - steps_per_tile_B.insert({std::make_tuple(r, k, n), f}); - } else { - it->second.insert(steps.size()); - } - auto a_rank = keymap_(Key<2>{m, k}); - if (a_sent[a_rank].find({m, k}) == a_sent[a_rank].end()) { - a_sent[a_rank].insert({m, k}); - a_in_comm_step[a_rank].insert(std::make_pair(m, k)); - if (a_in_comm_step[a_rank].size() >= comm_threshold_) { - comm_plan_A[a_rank].push_back(a_in_comm_step[a_rank]); - a_in_comm_step[a_rank].clear(); - } - } - auto b_rank = keymap_(Key<2>{k, n}); - if (b_sent[b_rank].find({k, n}) == b_sent[b_rank].end()) { - b_sent[b_rank].insert({k, n}); - b_in_comm_step[b_rank].insert(std::make_pair(k, n)); - if (b_in_comm_step[b_rank].size() >= comm_threshold_) { - comm_plan_B[b_rank].push_back(b_in_comm_step[b_rank]); - b_in_comm_step[b_rank].clear(); - } - } - } - } - } - steps.emplace_back(std::make_tuple(gemms, nb_local_gemms, local_gemms)); + long step_idx = 0; + const auto &keymap = keymap_; + const auto &comm_threshold = comm_threshold_; + for (long mm = 0; (nnz_in_AB > 0) && (mm < mns); mm++) { + for (long nn = 0; (nnz_in_AB > 0) && (nn < nns); nn++) { + for (long kk = 0; (nnz_in_AB > 0) && (kk < kns); kk++) { + local_gemms(step_idx, -1, + [&a_sent, &a_in_comm_step, &b_sent, &b_in_comm_step, &keymap, &comm_plan_A, &comm_plan_B, + &comm_threshold, &nnz_in_AB](long m, long n, long k) { + if (nnz_in_AB == 0) return false; + auto r = keymap(Key<2>({m, n})); + auto a_rank = keymap(Key<2>{m, k}); + if (a_sent[a_rank].find({m, k}) == a_sent[a_rank].end()) { + nnz_in_AB--; + a_sent[a_rank].insert({m, k}); + a_in_comm_step[a_rank].insert(std::make_pair(m, k)); + if (a_in_comm_step[a_rank].size() >= comm_threshold) { + comm_plan_A[a_rank].push_back(a_in_comm_step[a_rank]); + a_in_comm_step[a_rank].clear(); + } + } + auto b_rank = keymap(Key<2>{k, n}); + if (b_sent[b_rank].find({k, n}) == b_sent[b_rank].end()) { + nnz_in_AB--; + b_sent[b_rank].insert({k, n}); + b_in_comm_step[b_rank].insert(std::make_pair(k, n)); + if (b_in_comm_step[b_rank].size() >= comm_threshold) { + comm_plan_B[b_rank].push_back(b_in_comm_step[b_rank]); + b_in_comm_step[b_rank].clear(); + } + } + return nnz_in_AB > 0; + }); + step_idx++; } } } + assert(0 == nnz_in_AB); for (long r = 0; r < b_in_comm_step.size(); r++) { if (!b_in_comm_step[r].empty()) { comm_plan_B[r].push_back(b_in_comm_step[r]); @@ -580,57 +555,15 @@ class SpMM { a_in_comm_step[r].clear(); } } - return std::make_tuple(steps, steps_per_tile_A, steps_per_tile_B, comm_plan_A, comm_plan_B); + return std::make_tuple(comm_plan_A, comm_plan_B); } - void display_plan() const { + void display_comm_plan() const { if (!tracing()) return; auto rank = ttg_default_execution_context().rank(); - for (long i = 0; 0 == rank && i < std::get<0>(steps_).size(); i++) { - auto step = std::get<0>(steps_)[i]; - ttg::print("On rank", rank, "step", i, "has", std::get<1>(step), "local GEMMS and", std::get<0>(step).size(), - "GEMMs in total"); - if (rank == 0 && std::get<0>(step).size() < 30) { - std::ostringstream dbg; - dbg << "On rank " << rank << ", Step " << i << " is "; - for (auto it : std::get<0>(step)) { - dbg << "(" << std::get<0>(it) << "," << std::get<1>(it) << "," << std::get<2>(it) << ") "; - } - ttg::print(dbg.str()); - } else { - ttg::print("On rank", rank, - "full plan is not displayed because it is too large or displayed by another process"); - } - } - - const auto &steps_per_tile_A = std::get<1>(steps_); - const auto &steps_per_tile_B = std::get<2>(steps_); - if (0 == rank && steps_per_tile_A.size() <= 32 && steps_per_tile_B.size() <= 32) { - ttg::print("Displaying step list per tile of A on rank", rank); - for (auto const &it : steps_per_tile_A) { - std::stringstream steplist; - for (auto const s : it.second) { - steplist << s << ","; - } - ttg::print("On rank", rank, "rank", std::get<0>(it.first), "runs the following steps for A(", - std::get<1>(it.first), ",", std::get<2>(it.first), "):", steplist.str()); - } - ttg::print("Displaying step list per tile of B on rank", rank); - for (auto const &it : steps_per_tile_B) { - std::stringstream steplist; - for (auto const s : it.second) { - steplist << s << ","; - } - ttg::print("On rank", rank, "rank", std::get<0>(it.first), "runs the following steps for B(", - std::get<1>(it.first), ",", std::get<2>(it.first), "):", steplist.str()); - } - } else { - ttg::print("On rank", rank, "steps per tile of A is", steps_per_tile_A.size(), "too big to display"); - ttg::print("On rank", rank, "steps per tile of B is", steps_per_tile_B.size(), "too big to display"); - } - const auto &comm_plan_A = std::get<3>(steps_); - const auto &comm_plan_B = std::get<4>(steps_); + const auto &comm_plan_A = std::get<0>(comm_plan_); + const auto &comm_plan_B = std::get<1>(comm_plan_); bool display = (rank == 0); for (auto r = 0; display && r < comm_plan_A.size(); r++) { if (comm_plan_A[r].size() > 900) { @@ -754,75 +687,170 @@ class SpMM { abort(); // unreachable } - long nb_steps() const { return std::get<0>(steps_).size(); } + long nb_steps() const { + // assert(mns_ * nns_ * kns_ == std::get<0>(steps_).size()); + return mns_ * nns_ * kns_; + } std::tuple gemm_coordinates(long i, long j, long k) const { long p = i % this->p(); long q = j % this->q(); long r = q * this->p() + p; - for (long s = 0l; s < std::get<0>(steps_).size(); s++) { - const gemmset_t *gs = &std::get<0>(std::get<0>(steps_)[s]); - if (gs->find({i, j, k}) != gs->end()) { - return std::make_tuple(r, s); + + long s = (i / dim_) * nns_ * kns_ + (j / dim_) * kns_ + (k / dim_); + return std::make_tuple(r, s); + } + + template + void local_gemms(long s, long rank, TupleFn &&f) const { + long mm = s / (kns_ * nns_); + long nn = s % (kns_ * nns_) / nns_; + long kk = s % (kns_ * nns_) % nns_; + for (long m = mm * dim_; m < (mm + 1) * dim_ && m < mt_; m++) { + for (long n = nn * dim_; n < (nn + 1) * dim_ && n < nt_; n++) { + auto r = keymap_(Key<2>({m, n})); + if (rank != -1 && r != rank) continue; + const auto &a_k_range = a_rowidx_to_colidx_.at(m); + auto a_iter_fence = std::lower_bound(a_k_range.begin(), a_k_range.end(), (kk + 1) * dim_); + auto a_iter = std::lower_bound(a_k_range.begin(), a_iter_fence, kk * dim_); + if (a_iter == a_iter_fence) continue; + const auto &b_k_range = b_colidx_to_rowidx_.at(n); + auto b_iter_fence = std::lower_bound(b_k_range.begin(), b_k_range.end(), (kk + 1) * dim_); + auto b_iter = std::lower_bound(b_k_range.begin(), b_iter_fence, kk * dim_); + if (b_iter == b_iter_fence) continue; + while (true) { + auto a_colidx = *a_iter; + auto b_rowidx = *b_iter; + while (a_colidx != b_rowidx) { + if (a_colidx < b_rowidx) { + ++a_iter; + if (a_iter == a_iter_fence) break; + a_colidx = *a_iter; + } else { + ++b_iter; + if (b_iter == b_iter_fence) break; + b_rowidx = *b_iter; + } + } + if (a_iter == a_iter_fence) break; + if (b_iter == b_iter_fence) break; + auto ret = f(m, n, a_colidx); + if (!ret) return; + ++a_iter; + if (a_iter == a_iter_fence) break; + ++b_iter; + if (b_iter == b_iter_fence) break; + } } } - abort(); - return std::make_tuple(r, -1); } - struct GemmCoordinate { - long r_; - long c_; - const Blk v_; - - long row() { return r_; } - long col() { return c_; } - const Blk &value() { return v_; } - }; - - const gemmset_t &gemms(long s) const { return std::get<0>(std::get<0>(steps_)[s]); } - const gemmset_t &local_gemms(long s) const { return std::get<2>(std::get<0>(steps_)[s]); } - - long nb_local_gemms(long s) const { return std::get<1>(std::get<0>(steps_)[s]); } + long nb_local_gemms(long s) const { + long nb = 0; + auto count = [&nb](long m, long n, long k) { + nb++; + return true; + }; + local_gemms(s, ttg_default_execution_context().rank(), count); + return nb; + } /// Accessors to the local broadcast steps long first_step_A(long r, long i, long k) const { - const std::set &sv = std::get<1>(steps_).at(std::make_tuple(r, i, k)); - return *sv.begin(); + for (auto j = 0l; j < b_colidx_to_rowidx_.size(); j++) { + auto rank = keymap_(Key<2>{i, j}); + if (rank != r) continue; + for (auto kk = 0l; kk < b_colidx_to_rowidx_[j].size(); kk++) { + auto b_k = b_colidx_to_rowidx_[j][kk]; + if (b_k != k) continue; + long rank_gemm, step; + std::tie(rank_gemm, step) = gemm_coordinates(i, j, k); + assert(rank_gemm == r); + return step; + } + } + assert(0); + return -1; } long first_step_B(long r, long k, long j) const { - const std::set &sv = std::get<2>(steps_).at(std::make_tuple(r, k, j)); - return *sv.begin(); + for (auto i = 0l; i < a_rowidx_to_colidx_.size(); i++) { + auto rank = keymap_(Key<2>{i, j}); + if (rank != r) continue; + for (auto kk = 0l; kk < a_rowidx_to_colidx_[i].size(); kk++) { + auto a_k = a_rowidx_to_colidx_[i][kk]; + if (a_k != k) continue; + long rank_gemm, step; + std::tie(rank_gemm, step) = gemm_coordinates(i, j, k); + assert(rank_gemm == r); + return step; + } + } + assert(0); + return -1; } long next_step_A(long r, long i, long k, long s) const { - const std::set &sv = std::get<1>(steps_).at(std::make_tuple(r, i, k)); - auto it = sv.find(s); - assert(it != sv.end()); - it++; - if (it == sv.end()) return -1; - return *it; + bool found = false; + long step = -1; + for (auto j = 0l; j < b_colidx_to_rowidx_.size(); j++) { + auto rank = keymap_(Key<2>{i, j}); + if (rank != r) continue; + for (auto kk = 0l; kk < b_colidx_to_rowidx_[j].size(); kk++) { + auto b_k = b_colidx_to_rowidx_[j][kk]; + if (b_k != k) continue; + long s2, rank_gemm; + std::tie(rank_gemm, s2) = gemm_coordinates(i, j, k); + assert(rank_gemm == r); + if (found && s != s2) step = s2; + if (s == s2) found = true; + break; + } + if (step != -1) break; + } + return step; } long next_step_B(long r, long k, long j, long s) const { - const std::set &sv = std::get<2>(steps_).at(std::make_tuple(r, k, j)); - auto it = sv.find(s); - assert(it != sv.end()); - it++; - if (it == sv.end()) return -1; - return *it; + bool found = false; + long step = -1; + + for (auto i = 0l; i < a_rowidx_to_colidx_.size(); i++) { + auto rank = keymap_(Key<2>{i, j}); + if (rank != r) continue; + for (auto kk = 0l; kk < a_rowidx_to_colidx_[i].size(); kk++) { + auto a_k = a_rowidx_to_colidx_[i][kk]; + if (a_k != k) continue; + long rank_gemm, s2; + std::tie(rank_gemm, s2) = gemm_coordinates(i, j, k); + assert(rank_gemm == r); + if (found && s != s2) step = s2; + if (s == s2) found = true; + break; + } + if (step != -1) break; + } + return step; } /// Accessors to the communication plan + struct GemmCoordinate { + long r_; + long c_; + const Blk v_; + + long row() { return r_; } + long col() { return c_; } + const Blk &value() { return v_; } + }; long nb_comm_steps(long rank, bool is_a) const { const std::vector *cp; if (is_a) { - cp = &std::get<3>(steps_)[rank]; + cp = &std::get<0>(comm_plan_)[rank]; } else { - cp = &std::get<4>(steps_)[rank]; + cp = &std::get<1>(comm_plan_)[rank]; } return cp->size(); } @@ -832,11 +860,11 @@ class SpMM { std::vector res; const bcastset_t *bset; if (is_a) { - const auto &comm_plan = std::get<3>(steps_)[rank]; + const auto &comm_plan = std::get<0>(comm_plan_)[rank]; if (comm_step >= comm_plan.size()) return res; bset = &(comm_plan[comm_step]); } else { - const auto &comm_plan = std::get<4>(steps_)[rank]; + const auto &comm_plan = std::get<1>(comm_plan_)[rank]; if (comm_step >= comm_plan.size()) return res; bset = &(comm_plan[comm_step]); } @@ -856,9 +884,9 @@ class SpMM { const std::vector *cp; auto r = keymap_(Key<2>({i, j})); if (is_a) { - cp = &std::get<3>(steps_)[r]; + cp = &std::get<0>(comm_plan_)[r]; } else { - cp = &std::get<4>(steps_)[r]; + cp = &std::get<1>(comm_plan_)[r]; } long s; for (s = 0; s < cp->size() - 1; s++) { @@ -873,39 +901,14 @@ class SpMM { long p() const { return P_; } long q() const { return Q_; } - - std::pair gemmsperrankperphase() const { - double mean = 0.0, M2 = 0.0, delta, delta2; - long count = 0; - for (long phase = 0; phase < nb_steps(); phase++) { - const gemmset_t &gemms_in_phase = gemms(phase); - for (long rank = 0; rank < p() * q(); rank++) { - long nbgemm_in_phase_for_rank = 0; - for (auto g : gemms_in_phase) { - if (keymap_(Key<2>({std::get<0>(g), std::get<1>(g)})) == rank) nbgemm_in_phase_for_rank++; - } - double x = (double)nbgemm_in_phase_for_rank; - count++; - delta = x - mean; - mean += delta / count; - delta2 = x - mean; - M2 += delta * delta2; - } - } - if (count > 0) { - return std::make_pair(mean, sqrt(M2 / count)); - } else { - return std::make_pair(mean, nan("undefined")); - } - } }; /// Central coordinator: ensures that all progress according to the plan class Coordinator : public Op, std::tuple, Control>, Out, Control>, Out, Control>>, Coordinator, const Control> { public: - using baseT = - Op, std::tuple, Control>, Out, Control>, Out, Control>>, Coordinator, const Control>; + using baseT = Op, std::tuple, Control>, Out, Control>, Out, Control>>, Coordinator, + const Control>; Coordinator(Edge, Control> progress_ctl, Edge, Control> &a_ctl, Edge, Control> &b_ctl, Edge, Control> &c2c_ctl, std::shared_ptr plan, const Keymap &keymap) @@ -950,24 +953,28 @@ class SpMM { std::unordered_set, tuple_hash> seen_a; std::unordered_set, tuple_hash> seen_b; - for (auto x : plan_->local_gemms(s)) { - long gi, gj, gk; - std::tie(gi, gj, gk) = x; + std::vector> riks_keys; + std::vector> rkjs_keys; + auto bcast_gemms = [&r, &s, &seen_a, &seen_b, &riks_keys, &rkjs_keys](long gi, long gj, long gk) { if (seen_a.find(std::make_tuple(gi, gk)) == seen_a.end()) { if (tracing()) ttg::print("On rank", r, "Coordinator(", r, ", ", s, "): Sending control to LBCastA(", r, ",", gi, ",", gk, ",", s, ")"); - ::send<0>(Key<4>({r, gi, gk, s}), std::get<0>(input), out); + riks_keys.emplace_back(Key<4>({r, gi, gk, s})); seen_a.insert(std::make_tuple(gi, gk)); } if (seen_b.find(std::make_tuple(gk, gj)) == seen_b.end()) { if (tracing()) ttg::print("On rank", r, "Coordinator(", r, ", ", s, "): Sending control to LBCastB(", r, ",", gk, ",", gj, ",", s, ")"); - ::send<1>(Key<4>({r, gk, gj, s}), std::get<0>(input), out); + rkjs_keys.emplace_back(Key<4>({r, gk, gj, s})); seen_b.insert(std::make_tuple(gk, gj)); } - } + return true; + }; + plan_->local_gemms(s, ttg_default_execution_context().rank(), bcast_gemms); + ::broadcast<0>(riks_keys, std::get<0>(input), out); + ::broadcast<1>(rkjs_keys, std::get<0>(input), out); } private: @@ -1183,15 +1190,15 @@ class SpMM { if (tracing()) ttg::print("On rank", rank, "LBcastA(", r, ",", i, ",", k, ",", s, ")"); // broadcast A[i][k] to all local GEMMs in step s, then pass the data to the next step std::vector> ijk_keys; - for (const auto& x : plan_->local_gemms(s)) { - long gi, gj, gk; - std::tie(gi, gj, gk) = x; - if (gi != i || gk != k) continue; + auto bcast = [&ijk_keys, &i, &k, &rank, &s](long gi, long gj, long gk) { + if (gi != i || gk != k) return true; ijk_keys.emplace_back(Key<3>{gi, gj, gk}); if (tracing()) ttg::print("On rank", rank, "Giving A[", gi, ",", gk, "]", "to GEMM(", gi, ",", gj, ",", gk, ") during step", s); - } + return true; + }; + plan_->local_gemms(s, ttg_default_execution_context().rank(), bcast); ::broadcast<0>(ijk_keys, baseT::template get<0>(a_riks), out); auto ns = plan_->next_step_A(r, i, k, s); if (ns > -1) { @@ -1312,15 +1319,15 @@ class SpMM { if (tracing()) ttg::print("On rank", r, "LBcastB(", r, ",", k, ",", j, ",", s, ")"); // broadcast B[k][j] to all local GEMMs in step s, then pass the data to the next step std::vector> ijk_keys; - for (const auto& x : plan_->local_gemms(s)) { - long gi, gj, gk; - std::tie(gi, gj, gk) = x; - if (gj != j || gk != k) continue; + auto bcast = [&ijk_keys, &j, &k, &rank, &s](long gi, long gj, long gk) { + if (gj != j || gk != k) return true; ijk_keys.emplace_back(Key<3>{gi, gj, gk}); if (tracing()) ttg::print("On rank", rank, "Giving B[", gk, ",", gj, "]", "to GEMM(", gi, ",", gj, ",", gk, ") during step", s); - } + return true; + }; + plan_->local_gemms(s, ttg_default_execution_context().rank(), bcast); ::broadcast<0>(ijk_keys, baseT::template get<0>(b_rkjs), out); auto ns = plan_->next_step_B(r, k, j, s); if (ns > -1) { @@ -2025,8 +2032,9 @@ static void initBlSpLibint2(libint2::Operator libint2_op, libint2::any libint2_o const std::vector atoms, const std::string &basis_set_name, double tile_perelem_2norm_threshold, const std::function &)> &keymap, int maxTs, int nthreads, SpMatrix<> &A, SpMatrix<> &B, SpMatrix<> &Aref, SpMatrix<> &Bref, - bool buildRefs, std::string saveShapeId, std::vector &mTiles, std::vector &nTiles, - std::vector &kTiles, std::vector> &a_rowidx_to_colidx, + bool buildRefs, std::string saveShapeId, std::vector &mTiles, + std::vector &nTiles, std::vector &kTiles, + std::vector> &a_rowidx_to_colidx, std::vector> &a_colidx_to_rowidx, std::vector> &b_rowidx_to_colidx, std::vector> &b_colidx_to_rowidx, double &average_tile_volume, @@ -2093,14 +2101,16 @@ static void initBlSpLibint2(libint2::Operator libint2_op, libint2::any libint2_o mTiles = bsTiles; nTiles = bsTiles; kTiles = bsTiles; - std::cout << "{max,avg} tile size = {" << *std::max_element(bsTiles.begin(), bsTiles.end()) << "," <<(double)bs.nbf()/mTiles.size() << "}" << std::endl; + std::cout << "{max,avg} tile size = {" << *std::max_element(bsTiles.begin(), bsTiles.end()) << "," + << (double)bs.nbf() / mTiles.size() << "}" << std::endl; std::ofstream A_shp_os; bool first_tile = true; const auto saveShape = !saveShapeId.empty(); if (saveShape) { - A_shp_os.open("bspmm.A.id="+saveShapeId+".bs=" + basis_set_name + ".T=" + std::to_string(maxTs) + - ".eps=" + std::to_string(tile_perelem_2norm_threshold) + ".nb", std::ios_base::out | std::ios_base::trunc); + A_shp_os.open("bspmm.A.id=" + saveShapeId + ".bs=" + basis_set_name + ".T=" + std::to_string(maxTs) + + ".eps=" + std::to_string(tile_perelem_2norm_threshold) + ".nb", + std::ios_base::out | std::ios_base::trunc); A_shp_os << "SparseArray[{" << std::endl; } @@ -2183,7 +2193,7 @@ static void initBlSpLibint2(libint2::Operator libint2_op, libint2::any libint2_o ref_elements.emplace_back(row_tile_idx, col_tile_idx, tile); if (saveShape) { A_shp_os << (first_tile ? "" : ", ") << "{" << row_tile_idx + 1 << "," << col_tile_idx + 1 - << "} -> " << std::fixed << tile_perelem_2norm << std::endl; + << "} -> " << std::fixed << tile_perelem_2norm << std::endl; first_tile = false; } } @@ -2204,13 +2214,9 @@ static void initBlSpLibint2(libint2::Operator libint2_op, libint2::any libint2_o }; parallel_do(fill_matrix_impl); - if (saveShape) - A_shp_os << "}]" << std::endl; - - long nnz_tiles = elements.size(); // # of nonzero tiles, currently on this rank only + if (saveShape) A_shp_os << "}]" << std::endl; - // allreduce metadata: rowidx_to_colidx, colidx_to_rowidx, total_tile_volume, nnz_tiles - ttg_sum(ttg_default_execution_context(), nnz_tiles); + // allreduce metadata: rowidx_to_colidx, colidx_to_rowidx, total_tile_volume ttg_sum(ttg_default_execution_context(), total_tile_volume); auto allreduce_vevveclong = [&](std::vector> &vvl) { std::vector> vvl_result(vvl.size()); @@ -2265,13 +2271,14 @@ static void initBlSpLibint2(libint2::Operator libint2_op, libint2::any libint2_o #endif // defined(BLOCK_SPARSE_GEMM) static SpMatrix<> timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function &)> &keymap, - const std::string &tiling_type, double gflops, double avg_nb, double Adensity, - double Bdensity, const std::vector> &a_rowidx_to_colidx, - const std::vector> &a_colidx_to_rowidx, - const std::vector> &b_rowidx_to_colidx, - const std::vector> &b_colidx_to_rowidx, std::vector &mTiles, - std::vector &nTiles, std::vector &kTiles, int M, int N, int K, int P, int Q, - size_t memory, const long forced_split, long lookahead, long comm_threshold) { + const std::string &tiling_type, double gflops, double avg_nb, double Adensity, + double Bdensity, const std::vector> &a_rowidx_to_colidx, + const std::vector> &a_colidx_to_rowidx, + const std::vector> &b_rowidx_to_colidx, + const std::vector> &b_colidx_to_rowidx, std::vector &mTiles, + std::vector &nTiles, std::vector &kTiles, int M, int N, int K, int P, + int Q, size_t memory, const long forced_split, long lookahead, + long comm_threshold) { int MT = (int)A.rows(); int NT = (int)B.cols(); int KT = (int)A.cols(); @@ -2288,8 +2295,12 @@ static SpMatrix<> timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::fun Write_SpMatrix<> c(C, eC, keymap); auto &c_status = c.status(); assert(!has_value(c_status)); + auto constr_begin = std::chrono::high_resolution_clock::now(); SpMM<> a_times_b(ctl, eC, A, B, a_rowidx_to_colidx, a_colidx_to_rowidx, b_rowidx_to_colidx, b_colidx_to_rowidx, mTiles, nTiles, kTiles, keymap, P, Q, memory, forced_split, lookahead, comm_threshold); + auto constr_end = std::chrono::high_resolution_clock::now(); + double constr_duration = + std::chrono::duration_cast(constr_end - constr_begin).count() / 1e6; TTGUNUSED(a_times_b); auto connected = make_graph_executable(&control, a_times_b.get_reada(), a_times_b.get_readb()); @@ -2297,16 +2308,12 @@ static SpMatrix<> timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::fun TTGUNUSED(connected); MPI_Barrier(MPI_COMM_WORLD); - struct timeval start { - 0 - }, end{0}, diff{0}; - gettimeofday(&start, nullptr); + auto run_begin = std::chrono::high_resolution_clock::now(); // ready, go! need only 1 kick, so must be done by 1 thread only if (ttg_default_execution_context().rank() == 0) control.start(a_times_b.initbound()); ttg_fence(ttg_default_execution_context()); - gettimeofday(&end, nullptr); - timersub(&end, &start, &diff); - double tc = (double)diff.tv_sec + (double)diff.tv_usec / 1e6; + auto run_end = std::chrono::high_resolution_clock::now(); + double run_duration = std::chrono::duration_cast(run_end - run_begin).count() / 1e6; #if defined(TTG_USE_MADNESS) std::string rt("MAD"); #elif defined(TTG_USE_PARSEC) @@ -2315,15 +2322,12 @@ static SpMatrix<> timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::fun std::string rt("Unkown???"); #endif if (ttg_default_execution_context().rank() == 0) { - double avg, stdev; - std::tie(avg, stdev) = a_times_b.gemmsperrankperphase(); - std::cout << "TTG-" << rt << " PxQxg= " << P << " " << Q << " 1 average_NB= " << avg_nb << " M= " << M << " N= " << N << " K= " << K << " Tiling= " << tiling_type << " A_density= " << Adensity - << " B_density= " << Bdensity << " gflops= " << gflops << " seconds= " << tc - << " gflops/s= " << gflops / tc << " nb_phases= " << a_times_b.nbphases() << " lookahead= " << lookahead - << " average_nb_gemm_per_rank_per_phase= " << avg << " stdev_nb_gemm_per_rank_per_phase= " << stdev - << std::endl; + << " B_density= " << Bdensity << " gflops= " << gflops << " seconds= " << run_duration + << " gflops/s= " << gflops / run_duration << " nb_phases= " << a_times_b.nbphases() + << " lookahead= " << lookahead << " construction_duration= " << constr_duration + << " gflop/s_with_construction= " << gflops / (constr_duration + run_duration) << std::endl; } return C; @@ -2568,13 +2572,14 @@ int main(int argc, char **argv) { std::cerr << "#Generating matrices with Libint2 on " << xyz_filename << " and " << cores << " cores" << std::endl; auto start = std::chrono::high_resolution_clock::now(); initBlSpLibint2(libint2::Operator::yukawa, libint2::any{op_param}, atoms, basis_name, - tile_perelem_2norm_threshold, bc_keymap, maxTs, cores == -1 ? 1 : cores, A, B, - Aref, Bref, check, saveShapeId, mTiles, nTiles,kTiles, a_rowidx_to_colidx, - a_colidx_to_rowidx, b_rowidx_to_colidx, b_colidx_to_rowidx, avg_nb,Adensity, - Bdensity); + tile_perelem_2norm_threshold, bc_keymap, maxTs, cores == -1 ? 1 : cores, A, B, Aref, Bref, check, + saveShapeId, mTiles, nTiles, kTiles, a_rowidx_to_colidx, a_colidx_to_rowidx, b_rowidx_to_colidx, + b_colidx_to_rowidx, avg_nb, Adensity, Bdensity); auto end = std::chrono::high_resolution_clock::now(); - auto duration = duration_cast(end-start); - std::cerr << "#Generation done (" << duration.count()/1000000. << "s)" << std::endl; + auto duration = std::chrono::duration_cast(end - start); + std::cerr << "#Generation done (" << duration.count() / 1000000. << "s)" << std::endl; + std::cerr << "#Adensity " << Adensity << " mt " << mTiles.size() << " nt " << nTiles.size() << " kt " + << kTiles.size() << std::endl; tiling_type << xyz_filename << "_" << basis_name << "_" << tile_perelem_2norm_threshold << "_" << op_param; #endif C.resize(A.rows(), B.cols()); @@ -2600,9 +2605,9 @@ int main(int argc, char **argv) { // Start up engine ttg_execute(ttg_default_execution_context()); for (int nrun = 0; nrun < nb_runs; nrun++) { - C = timed_measurement(A, B, bc_keymap, tiling_type.str(), gflops, avg_nb, Adensity, Bdensity, a_rowidx_to_colidx, - a_colidx_to_rowidx, b_rowidx_to_colidx, b_colidx_to_rowidx, mTiles, nTiles, kTiles, M, N, K, - P, Q, memory, forced_split, lookahead, comm_threshold); + C = timed_measurement(A, B, bc_keymap, tiling_type.str(), gflops, avg_nb, Adensity, Bdensity, + a_rowidx_to_colidx, a_colidx_to_rowidx, b_rowidx_to_colidx, b_colidx_to_rowidx, mTiles, + nTiles, kTiles, M, N, K, P, Q, memory, forced_split, lookahead, comm_threshold); } #ifdef BSPMM_BUILD_TA_TEST @@ -2611,13 +2616,16 @@ int main(int argc, char **argv) { auto MT = mTiles.size(); auto NT = nTiles.size(); auto KT = kTiles.size(); - auto& mad_world = ttg_default_execution_context().impl().impl(); + auto &mad_world = ttg_default_execution_context().impl().impl(); // make tranges - auto make_trange1 = [](const auto& tile_sizes) { - std::vector hashes; hashes.reserve(tile_sizes.size()+1); + auto make_trange1 = [](const auto &tile_sizes) { + std::vector hashes; + hashes.reserve(tile_sizes.size() + 1); hashes.push_back(0); - for(auto& tile_size: tile_sizes) { hashes.push_back(hashes.back() + tile_size); } + for (auto &tile_size : tile_sizes) { + hashes.push_back(hashes.back() + tile_size); + } return TiledArray::TiledRange1(hashes.begin(), hashes.end()); }; auto mtr1 = make_trange1(mTiles); @@ -2628,13 +2636,13 @@ int main(int argc, char **argv) { TA::TiledRange C_trange({mtr1, ntr1}); // make shapes - auto make_shape = [&mad_world](const SpMatrix<>& mat, const auto& trange) { + auto make_shape = [&mad_world](const SpMatrix<> &mat, const auto &trange) { TA::Tensor norms(TA::Range(mat.rows(), mat.cols()), 0.); - for (int k=0; k::InnerIterator it(mat, k); it; ++it) { auto r = it.row(); // row index auto c = it.col(); // col index (here it is equal to k) - const auto& v = it.value(); + const auto &v = it.value(); norms(r, c) = std::sqrt(btas::dot(v, v)); } } @@ -2645,14 +2653,23 @@ int main(int argc, char **argv) { auto C_shape = make_shape(C, C_trange); // make pmaps - auto A_pmap = std::make_shared(mad_world, MT*KT, [&](size_t mk) -> size_t { auto [m, k] = std::div((long)mk, (long)KT); return tile2rank(m, k, P, Q); } ); - auto B_pmap = std::make_shared(mad_world, KT*NT, [&](size_t kn) -> size_t { auto [k, n] = std::div((long)kn, (long)NT); return tile2rank(k, n, P, Q); } ); - auto C_pmap = std::make_shared(mad_world, MT*NT, [&](size_t mn) -> size_t { auto [m, n] = std::div((long)mn, (long)NT); return tile2rank(m, n, P, Q); } ); + auto A_pmap = std::make_shared(mad_world, MT * KT, [&](size_t mk) -> size_t { + auto [m, k] = std::div((long)mk, (long)KT); + return tile2rank(m, k, P, Q); + }); + auto B_pmap = std::make_shared(mad_world, KT * NT, [&](size_t kn) -> size_t { + auto [k, n] = std::div((long)kn, (long)NT); + return tile2rank(k, n, P, Q); + }); + auto C_pmap = std::make_shared(mad_world, MT * NT, [&](size_t mn) -> size_t { + auto [m, n] = std::div((long)mn, (long)NT); + return tile2rank(m, n, P, Q); + }); // make distarrays - auto make_ta = [&mad_world](const SpMatrix<>& mat, const auto& trange, const auto& shape, const auto& pmap) { + auto make_ta = [&mad_world](const SpMatrix<> &mat, const auto &trange, const auto &shape, const auto &pmap) { TA::TSpArrayD mat_ta(mad_world, trange, shape, pmap); - for (int k=0; k::InnerIterator it(mat, k); it; ++it) { auto r = it.row(); // row index auto c = it.col(); // col index (here it is equal to k) @@ -2671,10 +2688,10 @@ int main(int argc, char **argv) { C_ta("m,n") = (A_ta("m,k") * B_ta("k,n")).set_shape(C_shape); C_ta.world().gop.fence(); auto end = std::chrono::high_resolution_clock::now(); - auto duration = duration_cast(end-start); - std::cout << "Time to compute C=A*B in TiledArray = " << duration.count()/1000000. << std::endl; - auto print = [](const auto& label, const SpMatrix<>& mat) { - for (int k=0; k(end - start); + std::cout << "Time to compute C=A*B in TiledArray = " << duration.count() / 1000000. << std::endl; + auto print = [](const auto &label, const SpMatrix<> &mat) { + for (int k = 0; k < mat.outerSize(); ++k) { for (SpMatrix<>::InnerIterator it(mat, k); it; ++it) { auto r = it.row(); // row index auto c = it.col(); // col index (here it is equal to k) @@ -2682,10 +2699,10 @@ int main(int argc, char **argv) { } } }; -// print("A", A); -// print("C", C); -// std::cout << "A_ta = " << A_ta << std::endl; -// std::cout << "C_ta = " << C_ta << std::endl; + // print("A", A); + // print("C", C); + // std::cout << "A_ta = " << A_ta << std::endl; + // std::cout << "C_ta = " << C_ta << std::endl; } } #endif @@ -2729,7 +2746,7 @@ int main(int argc, char **argv) { std::cout << "||Cref - C||_2 = " << std::sqrt(norm_2_square) << std::endl; std::cout << "||Cref - C||_\\infty = " << norm_inf << std::endl; if (norm_inf > 1e-9) { - if(Cref.nonZeros() < 100) { + if (Cref.nonZeros() < 100) { std::cout << "Cref:\n" << Cref << std::endl; std::cout << "C:\n" << C << std::endl; }