Skip to content

Commit 62b7d1e

Browse files
committed
Merge branch 'master' into kmp5/feature/CP
2 parents fbe772a + e574248 commit 62b7d1e

File tree

10 files changed

+117
-59
lines changed

10 files changed

+117
-59
lines changed

INSTALL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Both methods are supported. However, for most users we _strongly_ recommend to b
4343
- [Range-V3](https://github.com/ericniebler/range-v3.git) -- a Ranges library that served as the basis for Ranges component of C++20 and later.
4444
- [BTAS](http://github.com/ValeevGroup/BTAS), tag 1cfcb12647c768ccd83b098c64cda723e1275e49 . If a usable BTAS installation is not found, TiledArray will download and compile
4545
BTAS from source. *This is the recommended way to compile BTAS for all users*.
46-
- [MADNESS](https://github.com/m-a-d-n-e-s-s/madness), tag 93a9a5cec2a8fa87fba3afe8056607e6062a9058 .
46+
- [MADNESS](https://github.com/m-a-d-n-e-s-s/madness), tag bd84a52766ab497dedc2f15f2162fb0eb7ec4653 .
4747
Only the MADworld runtime and the BLAS/LAPACK C API components of MADNESS are used by TiledArray.
4848
If a usable MADNESS installation is not found, TiledArray will download and compile
4949
MADNESS from source. *This is the recommended way to compile MADNESS for all users*.

cmake/modules/FindOrFetchMADWorld.cmake

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ if (NOT TARGET MADworld)
4141

4242
# look for C and MPI here to make troubleshooting easier and be able to override defaults for MADNESS
4343
enable_language(C)
44-
find_package(MPI REQUIRED COMPONENTS C CXX)
44+
find_package(MPI REQUIRED COMPONENTS C)
4545

46+
set(FETCHCONTENT_QUIET FALSE)
4647
include(FetchContent)
4748
FetchContent_Declare(
4849
MADNESS
49-
GIT_REPOSITORY https://github.com/m-a-d-n-e-s-s/madness.git
50+
GIT_REPOSITORY https://github.com/m-a-d-n-e-s-s/madness.git
5051
GIT_TAG ${TA_TRACKED_MADNESS_TAG}
52+
GIT_PROGRESS TRUE
5153
)
5254
FetchContent_MakeAvailable(MADNESS)
5355
FetchContent_GetProperties(MADNESS

external/versions.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ set(TA_INSTALL_EIGEN_PREVIOUS_VERSION 3.3.7)
1111
set(TA_INSTALL_EIGEN_URL_HASH SHA256=b4c198460eba6f28d34894e3a5710998818515104d6e74e5cc331ce31e46e626)
1212
set(TA_INSTALL_EIGEN_PREVIOUS_URL_HASH MD5=b9e98a200d2455f06db9c661c5610496)
1313

14-
set(TA_TRACKED_MADNESS_TAG 8da56b1fc0b3d6eabe155923fb844f3430fc7d05)
15-
set(TA_TRACKED_MADNESS_PREVIOUS_TAG 93a9a5cec2a8fa87fba3afe8056607e6062a9058)
14+
set(TA_TRACKED_MADNESS_TAG bd84a52766ab497dedc2f15f2162fb0eb7ec4653)
15+
set(TA_TRACKED_MADNESS_PREVIOUS_TAG ef97ad1f0080da04f9592f03185c1a331cd5e001)
1616
set(TA_TRACKED_MADNESS_VERSION 0.10.1)
1717
set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1)
1818

python/src/TiledArray/python/module.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,8 @@ static World &initialize() {
4040
// this loads MPI before TA tries to do it
4141
int initialized = 0;
4242
MPI_Initialized(&initialized);
43-
MPI_Comm ta_comm;
4443

45-
if (!initialized) {
46-
ta_comm = MPI_COMM_WORLD;
47-
} else {
44+
if (initialized) {
4845
int thread_level;
4946
MPI_Query_thread(&thread_level);
5047
if (thread_level != MPI_THREAD_MULTIPLE)
@@ -58,10 +55,10 @@ static World &initialize() {
5855
char *_argv[0];
5956
char **argv = _argv;
6057
if (!madness::initialized()) {
61-
madness::initialize(argc, argv, ta_comm);
58+
madness::initialize(argc, argv);
6259
initialized_madness = true;
6360
}
64-
TiledArray::World &world = TiledArray::initialize(argc, argv, ta_comm);
61+
TiledArray::World &world = TiledArray::initialize(argc, argv, MPI_COMM_WORLD);
6562
TiledArray::set_default_world(world);
6663
if (world.rank() == 0) {
6764
std::cout << "initialized TA in a world with " << world.size() << " ranks"

src/TiledArray/conversions/make_array.h

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "TiledArray/array_impl.h"
3030
#include "TiledArray/external/madness.h"
31+
#include "TiledArray/pmap/replicated_pmap.h"
3132
#include "TiledArray/shape.h"
3233
#include "TiledArray/type_traits.h"
3334

@@ -73,7 +74,7 @@ template <typename Array, typename Op,
7374
typename std::enable_if<is_dense<Array>::value>::type* = nullptr>
7475
inline Array make_array(
7576
World& world, const detail::trange_t<Array>& trange,
76-
const std::shared_ptr<const detail::pmap_t<Array> >& pmap, Op&& op) {
77+
const std::shared_ptr<const detail::pmap_t<Array>>& pmap, Op&& op) {
7778
typedef typename Array::value_type value_type;
7879
typedef typename value_type::range_type range_type;
7980

@@ -150,10 +151,10 @@ template <typename Array, typename Op,
150151
typename std::enable_if<!is_dense<Array>::value>::type* = nullptr>
151152
inline Array make_array(
152153
World& world, const detail::trange_t<Array>& trange,
153-
const std::shared_ptr<const detail::pmap_t<Array> >& pmap, Op&& op) {
154+
const std::shared_ptr<const detail::pmap_t<Array>>& pmap, Op&& op) {
154155
typedef typename Array::value_type value_type;
155156
typedef typename Array::ordinal_type ordinal_type;
156-
typedef std::pair<ordinal_type, Future<value_type> > datum_type;
157+
typedef std::pair<ordinal_type, Future<value_type>> datum_type;
157158

158159
// Create a vector to hold local tiles
159160
std::vector<datum_type> tiles;
@@ -241,6 +242,41 @@ inline Array make_array(World& world, const detail::trange_t<Array>& trange,
241242
op);
242243
}
243244

245+
/// a make_array variant that uses a sequence of tiles
246+
/// to construct a DistArray with the default pmap, or with a replicated pmap if \p replicated is true
247+
template <typename Array, typename Tiles>
248+
Array make_array(World& world, const detail::trange_t<Array>& tiled_range,
249+
Tiles begin, Tiles end, bool replicated) {
250+
Array array;
251+
using Tuple = std::remove_reference_t<decltype(*begin)>;
252+
using Index = std::tuple_element_t<0, Tuple>;
253+
using shape_type = typename Array::shape_type;
254+
255+
std::shared_ptr<typename Array::pmap_interface> pmap;
256+
if (replicated) {
257+
size_t ntiles = tiled_range.tiles_range().volume();
258+
pmap = std::make_shared<detail::ReplicatedPmap>(world, ntiles);
259+
}
260+
261+
if constexpr (shape_type::is_dense()) {
262+
array = Array(world, tiled_range, pmap);
263+
} else {
264+
std::vector<std::pair<Index, float>> tile_norms;
265+
for (Tiles it = begin; it != end; ++it) {
266+
auto [index, tile] = *it;
267+
tile_norms.push_back({index, tile.norm()});
268+
}
269+
shape_type shape(world, tile_norms, tiled_range);
270+
array = Array(world, tiled_range, shape, pmap);
271+
}
272+
for (Tiles it = begin; it != end; ++it) {
273+
auto [index, tile] = *it;
274+
if (array.is_zero(index)) continue;
275+
array.set(index, tile);
276+
}
277+
return array;
278+
}
279+
244280
} // namespace TiledArray
245281

246282
#endif // TILEDARRAY_CONVERSIONS_MAKE_ARRAY_H__INCLUDED

src/TiledArray/dense_shape.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,17 @@ class DenseShape {
8989
return false;
9090
}
9191

92+
/// Check that a tile is zero
93+
94+
/// \tparam Integer an integer type
95+
/// \param i the index
96+
/// \return false
97+
template <typename Integer>
98+
std::enable_if_t<std::is_integral_v<Integer>, bool> is_zero(
99+
const std::initializer_list<Integer>& i) const {
100+
return false;
101+
}
102+
92103
/// Check density
93104

94105
/// \return true

src/TiledArray/dist_array.h

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,9 @@ class DistArray : public madness::archive::ParallelSerializableObject {
487487
/// initialized using the `op` function/functor, which transforms
488488
/// each tile in `other` using `op`
489489
/// \param other The array to be copied
490-
template <typename OtherTile, typename Op>
490+
template <typename OtherTile, typename Op,
491+
typename = std::enable_if_t<
492+
!std::is_same_v<detail::remove_cvr_t<Op>, TiledRange>>>
491493
DistArray(const DistArray<OtherTile, Policy>& other, Op&& op) : pimpl_() {
492494
*this = foreach<Tile>(other, std::forward<Op>(op));
493495
}
@@ -1878,39 +1880,6 @@ auto norm2(const DistArray<Tile, Policy>& a) {
18781880
return std::sqrt(squared_norm(a));
18791881
}
18801882

1881-
template <typename Array, typename Tiles>
1882-
Array make_array(World& world, const detail::trange_t<Array>& tiled_range,
1883-
Tiles begin, Tiles end, bool replicated) {
1884-
Array array;
1885-
using Tuple = std::remove_reference_t<decltype(*begin)>;
1886-
using Index = std::tuple_element_t<0, Tuple>;
1887-
using shape_type = typename Array::shape_type;
1888-
1889-
std::shared_ptr<typename Array::pmap_interface> pmap;
1890-
if (replicated) {
1891-
size_t ntiles = tiled_range.tiles_range().volume();
1892-
pmap = std::make_shared<detail::ReplicatedPmap>(world, ntiles);
1893-
}
1894-
1895-
if constexpr (shape_type::is_dense()) {
1896-
array = Array(world, tiled_range, pmap);
1897-
} else {
1898-
std::vector<std::pair<Index, float>> tile_norms;
1899-
for (Tiles it = begin; it != end; ++it) {
1900-
auto [index, tile] = *it;
1901-
tile_norms.push_back({index, tile.norm()});
1902-
}
1903-
shape_type shape(world, tile_norms, tiled_range);
1904-
array = Array(world, tiled_range, shape, pmap);
1905-
}
1906-
for (Tiles it = begin; it != end; ++it) {
1907-
auto [index, tile] = *it;
1908-
if (array.is_zero(index)) continue;
1909-
array.set(index, tile);
1910-
}
1911-
return array;
1912-
}
1913-
19141883
template <typename T, typename P>
19151884
DistArray<T, P> replicated(const DistArray<T, P>& a) {
19161885
auto& world = a.world();

src/TiledArray/einsum/tiledarray.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef TILEDARRAY_EINSUM_TILEDARRAY_H__INCLUDED
22
#define TILEDARRAY_EINSUM_TILEDARRAY_H__INCLUDED
33

4+
#include "TiledArray/conversions/make_array.h"
45
#include "TiledArray/dist_array.h"
56
#include "TiledArray/einsum/index.h"
67
#include "TiledArray/einsum/range.h"
@@ -410,6 +411,12 @@ template <DeNest DeNestFlag = DeNest::False, typename ArrayA_, typename ArrayB_,
410411
auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
411412
std::tuple<Einsum::Index<std::string>, Indices...> cs,
412413
World &world) {
414+
// hotfix: process all preceding tasks before entering this code with many
415+
// blocking calls
416+
// TODO figure out why having free threads left after blocking MPI split
417+
// is still not enough to ensure progress
418+
world.gop.fence();
419+
413420
using ArrayA = std::remove_cv_t<ArrayA_>;
414421
using ArrayB = std::remove_cv_t<ArrayB_>;
415422

@@ -739,7 +746,8 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
739746
for (Index h : H.tiles) {
740747
auto &[A, B] = AB;
741748
auto own = A.own(h) || B.own(h);
742-
auto comm = world.mpi.comm().Split(own, world.rank());
749+
auto comm = madness::blocking_invoke(&SafeMPI::Intracomm::Split,
750+
world.mpi.comm(), own, world.rank());
743751
worlds.push_back(std::make_unique<World>(comm));
744752
auto &owners = worlds.back();
745753
if (!own) continue;

src/TiledArray/sparse_shape.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,17 @@ class SparseShape {
509509
return tile_norms_[i] < my_threshold_;
510510
}
511511

512+
/// Check that a tile is zero
513+
514+
/// \tparam Integer an integer type
515+
/// \param i the index
516+
/// \return true if tile at position \p i is zero
517+
template <typename Integer>
518+
std::enable_if_t<std::is_integral_v<Integer>, bool> is_zero(
519+
const std::initializer_list<Integer>& i) const {
520+
return this->is_zero<std::initializer_list<Integer>>(i);
521+
}
522+
512523
/// Check density
513524

514525
/// \return true

src/TiledArray/tensor/tensor.h

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,7 @@ class Tensor {
431431
auto volume = total_size();
432432
for (decltype(volume) i = 0; i < volume; ++i) {
433433
auto& el = *(data() + i);
434-
if (!el.empty())
435-
el = p(el, inner_perm);
434+
if (!el.empty()) el = p(el, inner_perm);
436435
}
437436
}
438437
}
@@ -1633,10 +1632,19 @@ class Tensor {
16331632
if (right.empty()) return *this;
16341633
return binary(
16351634
right,
1636-
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
1635+
[](const value_type& l, const value_t<Right>& r) -> decltype(l + r) {
16371636
if constexpr (detail::is_tensor_v<value_type>) {
1638-
if (l.empty() && r.empty())
1639-
return value_type{};
1637+
if (l.empty()) {
1638+
if (r.empty())
1639+
return {};
1640+
else
1641+
return r;
1642+
} else {
1643+
if (r.empty())
1644+
return l;
1645+
else
1646+
return l + r;
1647+
}
16401648
}
16411649
return l + r;
16421650
});
@@ -1799,8 +1807,23 @@ class Tensor {
17991807
detail::tensors_have_equal_nested_rank_v<Tensor, Right>>>
18001808
Tensor subt(const Right& right) const {
18011809
return binary(
1802-
right, [](const value_type& l, const value_type& r) -> decltype(auto) {
1803-
return l - r;
1810+
right,
1811+
[](const value_type& l, const value_t<Right>& r) -> decltype(l - r) {
1812+
if constexpr (detail::is_tensor_v<value_type>) {
1813+
if (l.empty()) {
1814+
if (r.empty())
1815+
return {};
1816+
else
1817+
return -r;
1818+
} else {
1819+
if (r.empty())
1820+
return l;
1821+
else
1822+
return l - r;
1823+
}
1824+
} else {
1825+
return l - r;
1826+
}
18041827
});
18051828
}
18061829

@@ -1936,13 +1959,14 @@ class Tensor {
19361959
typename std::enable_if<detail::is_nested_tensor_v<Right>>::type* =
19371960
nullptr>
19381961
decltype(auto) mult(const Right& right) const {
1939-
1940-
auto mult_op =[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
1962+
auto mult_op = [](const value_type& l,
1963+
const value_t<Right>& r) -> decltype(auto) {
19411964
return l * r;
19421965
};
19431966

19441967
if (empty() || right.empty()) {
1945-
using res_t = decltype(std::declval<Tensor>().binary(std::declval<Right>(), mult_op));
1968+
using res_t = decltype(std::declval<Tensor>().binary(
1969+
std::declval<Right>(), mult_op));
19461970
return res_t{};
19471971
}
19481972

0 commit comments

Comments
 (0)