diff --git a/CMakeLists.txt b/CMakeLists.txt index 730d92abea..64fcb8dfc1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ # Jul 19, 2013 # -cmake_minimum_required (VERSION 3.15.0) # need list(PREPEND for toolchains +cmake_minimum_required (VERSION 3.21.0) # for HIP/ROCm # Set TiledArray version ======================================================= @@ -264,17 +264,13 @@ vgkit_cmake_git_metadata() ########################## # Check compiler features ########################## -# need C++17, insist on strict standard -set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ ISO Standard version") -if (NOT(CMAKE_CXX_STANDARD EQUAL 17 OR CMAKE_CXX_STANDARD EQUAL 20)) - message(FATAL_ERROR "C++ 2017 ISO Standard or higher is required to compile TiledArray") -endif() -# C++20 is only configurable via compile features with cmake 3.12 and older -if (CMAKE_CXX_STANDARD EQUAL 20 AND CMAKE_VERSION VERSION_LESS 3.12.0) - cmake_minimum_required (VERSION 3.12.0) +# need C++20, insist on strict standard +set(CMAKE_CXX_STANDARD 20 CACHE STRING "C++ ISO Standard version") +if (CMAKE_CXX_STANDARD LESS 20) + message(FATAL_ERROR "C++ 2020 ISO Standard or higher is required to compile TiledArray") endif() set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_EXTENSIONS OFF CACHE BOOL "Whether to use extensions of C++ ISO Standard version") +set(CMAKE_CXX_EXTENSIONS OFF CACHE BOOL "Whether to use extensions of C++ ISO Standard version") # Check type support include(CheckTypeSize) check_type_size("long double" TILEDARRAY_HAS_LONG_DOUBLE LANGUAGE CXX) diff --git a/INSTALL.md b/INSTALL.md index 742d967f71..677f976dbb 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -24,15 +24,14 @@ Both methods are supported. However, for most users we _strongly_ recommend to b ## Prerequisites -- C++ compiler with support for the [C++17 standard](http://www.iso.org/standard/68564.html), or a more recent standard. This includes the following compilers: - - [GNU C++](https://gcc.gnu.org/), version 7.0 or higher - - [Clang](https://clang.llvm.org/), version 5 or higher - - [Apple Clang](https://en.wikipedia.org/wiki/Xcode), version 9.3 or higher - - [Intel C++ compiler](https://software.intel.com/en-us/c-compilers), version 19 or higher +- C++ compiler with support for the [C++20 standard](http://www.iso.org/standard/68564.html), or a more recent standard. This includes the following compilers: + - [GNU C++](https://gcc.gnu.org/), version 11 or higher + - [Clang](https://clang.llvm.org/), version 14 or higher + - [Apple Clang](https://en.wikipedia.org/wiki/Xcode), version 14 or higher See the current [Travis CI matrix](.travis.yml) for the most up-to-date list of compilers that are known to work. -- [CMake](https://cmake.org/), version 3.15 or higher; if {CUDA,HIP} support is needed, CMake {3.18,3.21} or higher is required. +- [CMake](https://cmake.org/), version 3.21 or higher. - [Git](https://git-scm.com/) 1.8 or later (required to obtain TiledArray and MADNESS source code from GitHub) - [Eigen](http://eigen.tuxfamily.org/), version 3.3.5 or higher; if CUDA is enabled then 3.3.7 is required (will be downloaded automatically, if missing) - [Boost libraries](www.boost.org/), version 1.81 or higher (will be downloaded automatically, if missing). The following principal Boost components are used: @@ -66,14 +65,14 @@ Compiling BTAS requires the following prerequisites: Optional prerequisites: - for execution on GPGPUs: - device programming runtime: - - [CUDA compiler and runtime](https://developer.nvidia.com/cuda-zone) -- for execution on NVIDIA's CUDA-enabled accelerators. CUDA 11 or later is required. + - [CUDA compiler and runtime](https://developer.nvidia.com/cuda-zone) -- for execution on NVIDIA's CUDA-enabled accelerators. CUDA 12 or later is required. - [HIP/ROCm compiler and runtime](https://developer.nvidia.com/cuda-zone) -- for execution on AMD's ROCm-enabled accelerators. Note that TiledArray does not use ROCm directly but its C++ Heterogeneous-Compute Interface for Portability, `HIP`; although HIP can also be used to program CUDA-enabled devices, in TiledArray it is used only to program ROCm devices, hence ROCm and HIP will be used interchangeably. - [LibreTT](github.com/victor-anisimov/LibreTT) -- free tensor transpose library for CUDA, ROCm, and SYCL platforms that is based on the [original cuTT library](github.com/ap-hynninen/cutt) extended to provide thread-safety improvements (via github.com/ValeevGroup/cutt) and extended to non-CUDA platforms by [@victor-anisimov](github.com/victor-anisimov) (tag 6eed30d4dd2a5aa58840fe895dcffd80be7fbece). - [Umpire](github.com/LLNL/Umpire) -- portable memory manager for heterogeneous platforms (tag 8c85866107f78a58403e20a2ae8e1f24c9852287). - [Doxygen](http://www.doxygen.nl/) -- for building documentation (version 1.8.12 or later). - [ScaLAPACK](http://www.netlib.org/scalapack/) -- a distributed-memory linear algebra package. If detected, the following C++ components will also be sought and downloaded, if missing: - - [scalapackpp](https://github.com/wavefunction91/scalapackpp.git) -- a modern C++ (C++17) wrapper for ScaLAPACK (tag 6397f52cf11c0dfd82a79698ee198a2fce515d81); pulls and builds the following additional prerequisite - - [blacspp](https://github.com/wavefunction91/blacspp.git) -- a modern C++ (C++17) wrapper for BLACS + - [scalapackpp](https://github.com/wavefunction91/scalapackpp.git) -- a modern C++ wrapper for ScaLAPACK (tag 6397f52cf11c0dfd82a79698ee198a2fce515d81); pulls and builds the following additional prerequisite + - [blacspp](https://github.com/wavefunction91/blacspp.git) -- a modern C++ wrapper for BLACS - Python3 interpreter -- to test (optionally-built) Python bindings - [TTG](https://github.com/TESSEorg/ttg.git) -- C++ implementation of the Template Task Graph programming model for fine-grained flow-graph composition of distributed memory programs (tag 3fe4a06dbf4b05091269488aab38223da1f8cb8e). @@ -186,7 +185,7 @@ Additional CMake variables are given below. * `CMAKE_BUILD_TYPE` -- Optimization/debug build type options include `Debug` (optimization off, debugging symbols and assersions on), `Release` (optimization on, debugging symbols and assertions off), `RelWithDebInfo` (optimization on, debugging symbols and assertions on) and `MinSizeRel` (same as `Release` but optimized for executable size). The default is empty build type. It is recommended that you set the build type explicitly. * `BUILD_SHARED_LIBS` -- Enable shared libraries. This option is only available if the platform supports shared libraries; if that's true and `TA_ASSUMES_ASLR_DISABLED` is `ON` (see below) the default is `ON`, otherwise the default is `OFF`. -* `CMAKE_CXX_STANDARD` -- Specify the C++ ISO Standard to use. Valid values are `17` (default), and `20`. +* `CMAKE_CXX_STANDARD` -- Specify the C++ ISO Standard to use. Valid values are `20` (default), and `23`. Most of these are best specified in a _toolchain file_. TiledArray is recommended to use the toolchains distributed via [the Valeev Group CMake kit](https://github.com/ValeevGroup/kit-cmake/tree/master/toolchains). TiledArray by default downloads (via [the FetchContent CMake module](https://cmake.org/cmake/help/latest/module/FetchContent.html)) the VG CMake toolkit which makes the toolchains available without having to download the toolchain files manually. E.g., to use toolchain `x` from the VG CMake kit repository provide `-DCMAKE_TOOLCHAIN_FILE=cmake/vg/toolchains/x.cmake` to CMake when configuring TiledArray. diff --git a/doc/dox/dev/Basic-Programming.md b/doc/dox/dev/Basic-Programming.md index ee9f08cc4f..83219ff5e2 100644 --- a/doc/dox/dev/Basic-Programming.md +++ b/doc/dox/dev/Basic-Programming.md @@ -66,7 +66,7 @@ An object that specifies the structure of DistArray. E.g. it could be represente ## Implementation -TiledArray is a library written in standard C++ using features available in the 2017 ISO standard (commonly known as C++17). To use TiledArray it is necessary to `#include` header `tiledarray.h`. imports most TiledArray features into namespace `TiledArray`. For convenience, namespace alias `TA` is also provided. Although the alias can be disabled by defining the `TILEDARRAY_DISABLE_NAMESPACE_TA` preprocessor variable, all examples will assume that the `TA` alias is not disabled. +TiledArray is a library written in standard C++ using features available in the 2020 ISO standard (commonly known as C++20). To use TiledArray it is necessary to `#include` header `tiledarray.h`. imports most TiledArray features into namespace `TiledArray`. For convenience, namespace alias `TA` is also provided. Although the alias can be disabled by defining the `TILEDARRAY_DISABLE_NAMESPACE_TA` preprocessor variable, all examples will assume that the `TA` alias is not disabled. P.S. It sometimes may be possible to reduce source code couplings by importing only _forwarding_ declarations. This is done by `#include`ing header `tiledarray_fwd.h`. diff --git a/src/TiledArray/conversions/make_array.h b/src/TiledArray/conversions/make_array.h index 306b61ee34..09ef97874c 100644 --- a/src/TiledArray/conversions/make_array.h +++ b/src/TiledArray/conversions/make_array.h @@ -242,7 +242,7 @@ inline Array make_array(World& world, const detail::trange_t& trange, op); } -/// a make_array variant that uses a sequence of tiles +/// a make_array variant that uses a sequence of {tile_index,tile} pairs /// to construct a DistArray with default pmap template Array make_array(World& world, const detail::trange_t& tiled_range, diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 6b61610abc..cc5ded01cb 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -600,6 +600,8 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, std::invoke(update_perm_and_indices, std::get<0>(AB)); std::invoke(update_perm_and_indices, std::get<1>(AB)); + // construct result, with "dense" DistArray; the array will be + // reconstructred from local tiles later ArrayTerm C = {ArrayC(world, TiledRange(range_map[c])), c}; for (auto idx : e) { C.tiles *= Range(range_map[idx].tiles_range()); @@ -609,6 +611,16 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } C.expr = e; + using Index = Einsum::Index; + + // this will collect local tiles of C.array, to be used to rebuild C.array + std::vector> C_local_tiles; + auto build_C_array = [&]() { + C.array = make_array(world, TiledRange(range_map[c]), + C_local_tiles.begin(), C_local_tiles.end(), + /* replicated = */ false); + }; + std::get<0>(AB).expr += inner.a; std::get<1>(AB).expr += inner.b; @@ -627,19 +639,56 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } } - using Index = Einsum::Index; - if (!e) { // hadamard reduction + auto &[A, B] = AB; TiledRange trange(range_map[i]); RangeProduct tiles; for (auto idx : i) { tiles *= Range(range_map[idx].tiles_range()); } + + // the inner product can be either hadamard or a contraction + using TensorT = typename decltype(A.array)::value_type::value_type; + static_assert( + std::is_same_v); + constexpr bool is_tot = detail::is_tensor_v; + auto element_hadamard_op = + (is_tot && inner.h) + ? std::make_optional( + [&inner, plan = detail::TensorHadamardPlan(inner.A, inner.B, + inner.C)]( + auto const &l, auto const &r) -> TensorT { + if (l.empty() || r.empty()) return TensorT{}; + return detail::tensor_hadamard(l, r, plan); + }) + : std::nullopt; + auto element_contract_op = + (is_tot && !inner.h) + ? std::make_optional( + [&inner, plan = detail::TensorContractionPlan( + inner.A, inner.B, inner.C)]( + auto const &l, auto const &r) -> TensorT { + if (l.empty() || r.empty()) return TensorT{}; + return detail::tensor_contract(l, r, plan); + }) + : std::nullopt; + auto element_product_op = [&inner, &element_hadamard_op, + &element_contract_op]( + auto const &l, auto const &r) -> TensorT { + TA_ASSERT(inner.h ? element_hadamard_op.has_value() + : element_contract_op.has_value()); + return inner.h ? element_hadamard_op.value()(l, r) + : element_contract_op.value()(l, r); + }; + auto pa = A.permutation; auto pb = B.permutation; for (Index h : H.tiles) { - if (!C.array.is_local(h)) continue; + auto const pc = C.permutation; + auto const c = apply(pc, h); + if (!C.array.is_local(c)) continue; size_t batch = 1; for (size_t i = 0; i < h.size(); ++i) { batch *= H.batch[i].at(h[i]); @@ -670,16 +719,8 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, auto &el = tile({k}); using TensorT = std::remove_reference_t; - auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT { - if (l.empty() || r.empty()) return TensorT{}; - return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r, - inner.B, inner.C) - : TA::detail::tensor_contract(l, inner.A, r, - inner.B, inner.C); - }; - for (auto i = 0; i < vol; ++i) - el.add_to(mult_op(aik.data()[i], bik.data()[i])); + el.add_to(element_product_op(aik.data()[i], bik.data()[i])); } else if constexpr (!AreArraySame) { auto aik = ai.batch(k); @@ -702,14 +743,21 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } } } - auto pc = C.permutation; - auto shape = apply_inverse(pc, C.array.trange().tile(h)); + // data is stored as h1 h2 ... but all modes folded as 1 batch dim + // first reshape to h = (h1 h2 ...) + // n.b. can't just use shape = C.array.trange().tile(h) + auto shape = apply_inverse(pc, C.array.trange().tile(c)); tile = tile.reshape(shape); + // then permute to target C layout c = (c1 c2 ...) if (pc) tile = tile.permute(pc); - C.array.set(h, tile); + // and move to C_local_tiles + C_local_tiles.emplace_back(std::move(c), std::move(tile)); } + + build_C_array(); + return C.array; - } + } // end: hadamard reduction // generalized contraction @@ -740,7 +788,6 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, std::invoke(update_tr, std::get<1>(AB)); std::vector> worlds; - std::vector> local_tiles; // iterates over tiles of hadamard indices for (Index h : H.tiles) { @@ -798,26 +845,13 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, shape = apply_inverse(P, shape); tile = tile.reshape(shape); if (P) tile = tile.permute(P); - local_tiles.push_back({c, tile}); + C_local_tiles.emplace_back(std::move(c), std::move(tile)); } // mark for lazy deletion C.ei = ArrayC(); } - if constexpr (!ResultShape::is_dense()) { - TiledRange tiled_range = TiledRange(range_map[c]); - std::vector> tile_norms; - for (auto &[index, tile] : local_tiles) { - tile_norms.push_back({index, tile.norm()}); - } - ResultShape shape(world, tile_norms, tiled_range); - C.array = ArrayC(world, TiledRange(range_map[c]), shape); - } - - for (auto &[index, tile] : local_tiles) { - if (C.array.is_zero(index)) continue; - C.array.set(index, tile); - } + build_C_array(); for (auto &w : worlds) { w->gop.fence(); diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index 34e3ea0c9a..fdeb0c77b5 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -1158,69 +1158,82 @@ Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, return result; } -/// -/// todo: constraint ResultTensorAllocator type so that non-sensical Allocators -/// are prohibited -/// -template && - is_annotation_v>> -auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, - Annot const& aB, Annot const& aC) { - using Result = result_tensor_t, TensorA, TensorB, - ResultTensorAllocator>; - - using Indices = ::Einsum::index::Index; - using Permutation = ::Einsum::index::Permutation; - using ::Einsum::index::permutation; - - // Check that the ranks of the tensors match that of the annotation. - TA_ASSERT(A.range().rank() == aA.size()); - TA_ASSERT(B.range().rank() == aB.size()); - - struct { - Indices // - A, // indices of A - B, // indices of B - C, // indices of C (target indices) - h, // Hadamard indices (aA intersection aB intersection aC) - e, // external indices (aA symmetric difference aB) - i; // internal indices ((aA intersection aB) set difference aC) - } const indices{aA, - aB, - aC, - (indices.A & indices.B & indices.C), - (indices.A ^ indices.B), - ((indices.A & indices.B) - indices.h)}; - - TA_ASSERT(!indices.h && "Hadamard indices not supported"); - TA_ASSERT(indices.e && "Dot product not supported"); +/// plan for a binary Tensor contraction of fixed topology +template >> +struct TensorContractionPlan { + using Indices = Einsum::index::Index; + using Permutation = Einsum::index::Permutation; + + const Indices // + A, // indices of A + B, // indices of B + C, // indices of C (target indices) + h, // Hadamard indices (aA intersection aB intersection aC) + e, // external indices (aA symmetric difference aB) + i; // internal indices ((aA intersection aB) set difference aC) struct { Indices A, B, C; - } const blas_layout{(indices.A - indices.B) | indices.i, - indices.i | (indices.B - indices.A), indices.e}; + } const blas_layout; struct { Permutation A, B, C; - } const perm{permutation(indices.A, blas_layout.A), - permutation(indices.B, blas_layout.B), - permutation(indices.C, blas_layout.C)}; + } const perm; struct { bool A, B, C; - } const do_perm{indices.A != blas_layout.A, indices.B != blas_layout.B, - indices.C != blas_layout.C}; + } const do_perm; + + const math::GemmHelper gemm_helper; + + /// constructs plan for contraction C(aC) = A(aA) * B(aB). E.g. + /// `TensorContractionPlan("i,k", "k,j", "i,j")` constructs a plan + /// for matrix product. + /// \param aA einsum annotation for first argument (A) + /// \param aB einsum annotation for second argument (B) + /// \param aC einsum annotation for the result (C) + TensorContractionPlan(Annot const& aA, Annot const& aB, Annot const& aC) + : A(aA), + B(aB), + C(aC), + h(A & B & C), + e(A ^ B), + i((A & B) - h), + blas_layout{(A - B) | i, i | (B - A), e}, + perm{Einsum::index::permutation(A, blas_layout.A), + Einsum::index::permutation(B, blas_layout.B), + Einsum::index::permutation(C, blas_layout.C)}, + do_perm{A != blas_layout.A, B != blas_layout.B, C != blas_layout.C}, + gemm_helper{blas::Op::NoTrans, blas::Op::NoTrans, + static_cast(e.size()), + static_cast(A.size()), + static_cast(B.size())} { + TA_ASSERT(!h && "Hadamard indices not supported"); + TA_ASSERT(e && "Dot product not supported"); + } +}; - math::GemmHelper gemm_helper{blas::Op::NoTrans, blas::Op::NoTrans, - static_cast(indices.e.size()), - static_cast(indices.A.size()), - static_cast(indices.B.size())}; +/// contracts 2 tensors using the given contraction \p plan . +/// @internal TODO constrain ResultTensorAllocator type so that non-sensical +/// Allocators are prohibited +/// @return result of the contraction +template && + is_annotation_v>> +auto tensor_contract(TensorA const& A, TensorB const& B, + const TensorContractionPlan& plan) { + using Result = result_tensor_t, TensorA, TensorB, + ResultTensorAllocator>; + + // Check that the ranks of the tensors match that of the annotation. + TA_ASSERT(A.range().rank() == plan.A.size()); + TA_ASSERT(B.range().rank() == plan.B.size()); // initialize result with the correct extents Result result; { + using Indices = Einsum::index::Index; using Index = typename Indices::value_type; using Extent = std::remove_cv_t< typename decltype(std::declval().extent())::value_type>; @@ -1230,12 +1243,12 @@ auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, // Note that whether the contracting indices have matching extents is // implicitly checked here by the pipe(|) operator on ExtentMap. - ExtentMap extent = (ExtentMap{indices.A, A.range().extent()} | - ExtentMap{indices.B, B.range().extent()}); + ExtentMap extent = (ExtentMap{plan.A, A.range().extent()} | + ExtentMap{plan.B, B.range().extent()}); container::vector rng; - rng.reserve(indices.e.size()); - for (auto&& ix : indices.e) { + rng.reserve(plan.e.size()); + for (auto&& ix : plan.e) { // assuming ix _exists_ in extent rng.emplace_back(extent[ix]); } @@ -1245,58 +1258,108 @@ auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, using Numeric = typename Result::numeric_type; // call gemm - gemm(Numeric{1}, // - do_perm.A ? A.permute(perm.A) : A, // - do_perm.B ? B.permute(perm.B) : B, // - Numeric{0}, result, gemm_helper); + gemm(Numeric{1}, // + plan.do_perm.A ? A.permute(plan.perm.A) : A, // + plan.do_perm.B ? B.permute(plan.perm.B) : B, // + Numeric{0}, result, plan.gemm_helper); - return do_perm.C ? result.permute(perm.C.inv()) : result; + return plan.do_perm.C ? result.permute(plan.perm.C.inv()) : result; } -template && is_annotation_v>> -auto tensor_hadamard(TensorA const& A, Annot const& aA, TensorB const& B, +auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, Annot const& aB, Annot const& aC) { - using ::Einsum::index::Permutation; - using ::Einsum::index::permutation; - using Indices = ::Einsum::index::Index; + using Result = result_tensor_t, TensorA, TensorB, + ResultTensorAllocator>; + + TensorContractionPlan plan(aA, aB, aC); + + return tensor_contract(A, B, plan); +} + +/// plan for Tensor contractions of fixed topology +template >> +struct TensorHadamardPlan { + using Indices = Einsum::index::Index; + using Permutation = Einsum::index::Permutation; + + const Indices // + A, // indices of A + B, // indices of B + C; // indices of C (target indices) struct { Permutation // AB, // permutes A to B AC, // permutes A to C BC; // permutes B to C - } const perm{permutation(Indices(aA), Indices(aB)), - permutation(Indices(aA), Indices(aC)), - permutation(Indices(aB), Indices(aC))}; + } const perm; + + const bool no_perm, perm_to_c, perm_a, perm_b; + + /// constructs plan for generalized hadamard product C(aC) = A(aA) * B(aB). + /// E.g. `TensorHadamardPlan("i,j", "i,j", "j,i")` constructs a plan + /// for product C(j,i) = A(i,j) B (i,j) + /// \param aA einsum annotation for first argument (A) + /// \param aB einsum annotation for second argument (B) + /// \param aC einsum annotation for the result (C) + TensorHadamardPlan(Annot const& aA, Annot const& aB, Annot const& aC) + : A(aA), + B(aB), + C(aC), + perm{Einsum::index::permutation(A, B), Einsum::index::permutation(A, C), + Einsum::index::permutation(B, C)}, + no_perm(perm.AB.is_identity() && perm.AC.is_identity() && + perm.BC.is_identity()), + perm_to_c(perm.AB.is_identity()), + perm_a(perm.BC.is_identity()), // + perm_b(perm.AC.is_identity()) {} +}; - struct { - bool no_perm, perm_to_c, perm_a, perm_b; - } const do_this{ - perm.AB.is_identity() && perm.AC.is_identity() && perm.BC.is_identity(), - perm.AB.is_identity(), // - perm.BC.is_identity(), // - perm.AC.is_identity()}; - - if (do_this.no_perm) { +template && + is_annotation_v>> +auto tensor_hadamard(TensorA const& A, TensorB const& B, + const TensorHadamardPlan& plan) { + // Check that the ranks of the tensors match that of the annotation. + TA_ASSERT(A.range().rank() == plan.A.size()); + TA_ASSERT(B.range().rank() == plan.B.size()); + + if (plan.no_perm) { return A.mult(B); - } else if (do_this.perm_to_c) { - return A.mult(B, perm.AC); - } else if (do_this.perm_a) { - auto pA = A.permute(perm.AC); + } else if (plan.perm_to_c) { + return A.mult(B, plan.perm.AC); + } else if (plan.perm_a) { + auto pA = A.permute(plan.perm.AC); pA.mult_to(B); return pA; - } else if (do_this.perm_b) { - auto pB = B.permute(perm.BC); + } else if (plan.perm_b) { + auto pB = B.permute(plan.perm.BC); pB.mult_to(A); return pB; } else { - auto pA = A.permute(perm.AC); - return pA.mult_to(B.permute(perm.BC)); + auto pA = A.permute(plan.perm.AC); + return pA.mult_to(B.permute(plan.perm.BC)); } } +template && + is_annotation_v>> +auto tensor_hadamard(TensorA const& A, Annot const& aA, TensorB const& B, + Annot const& aB, Annot const& aC) { + TensorHadamardPlan plan(aA, aB, aC); + + return tensor_hadamard(A, B, plan); +} + } // namespace detail } // namespace TiledArray diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 6be4a4a99d..6d32285de2 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -25,7 +25,7 @@ #include "TiledArray/expressions/contraction_helpers.h" -BOOST_AUTO_TEST_SUITE(manual) +BOOST_AUTO_TEST_SUITE(einsum_manual) namespace { using il_trange = std::initializer_list>; @@ -216,18 +216,20 @@ BOOST_AUTO_TEST_CASE(equal_nested_ranks) { {4, 3})); // H+C;H - BOOST_REQUIRE(check_manual_eval("ijk;mn,ijk;nm->ij;mn", // - {{0, 2}, {0, 3}, {0, 2}}, // - {{0, 2}, {0, 3}, {0, 2}}, // - {2, 2}, // - {2, 2})); + BOOST_REQUIRE( + check_manual_eval("jki;mn,ijk;nm->ij;mn", // + {{0, 2, 3}, {0, 3, 5}, {0, 1, 3}}, // + {{0, 1, 3}, {0, 2, 3}, {0, 3, 5}}, // + {3, 2}, // + {2, 3})); // H+C;C - BOOST_REQUIRE(check_manual_eval("ijk;mo,ijk;no->ij;nm", // - {{0, 2}, {0, 3}, {0, 2}}, // - {{0, 2}, {0, 3}, {0, 2}}, // - {3, 2}, // - {3, 2})); + BOOST_REQUIRE( + check_manual_eval("ijk;mo,kji;no->ij;nm", // + {{0, 1, 3}, {0, 2, 3}, {0, 3, 5}}, // + {{0, 3, 5}, {0, 2, 3}, {0, 1, 3}}, // + {3, 2}, // + {4, 2})); // H+C;C BOOST_REQUIRE(check_manual_eval("ijk;m,ijk;n->ij;nm", // @@ -240,6 +242,30 @@ BOOST_AUTO_TEST_CASE(equal_nested_ranks) { // H;C(op) BOOST_REQUIRE(check_manual_eval( "ijk;bc,j;d->kji;dcb", {{0, 1}, {0, 1}, {0, 1}}, {{0, 1}}, {2, 3}, {4})); + + // H+C;C + BOOST_REQUIRE(check_manual_eval( + "jki;ad,ikj;db->ij;ab", // + {{0, 1, 2, 3, 4, 5, 6}, {0, 1, 2, 3}, {0, 1, 2, 3, 4}}, // + {{0, 1, 2, 3, 4}, {0, 1, 2, 3}, {0, 1, 2, 3, 4, 5, 6}}, // + {3, 2}, // + {2, 4})); + + // H+C;C + BOOST_REQUIRE( + check_manual_eval("ijk;mo,kji;no->ik;nm", // + {{0, 3, 6}, {0, 1, 3}, {0, 2, 4}}, // + {{0, 2, 4}, {0, 1, 3}, {0, 3, 6}}, // + {3, 2}, // + {4, 2})); + + // H+C;C + BOOST_REQUIRE( + check_manual_eval("ijk;mo,ijk;no->ji;nm", // + {{0, 2, 5}, {0, 1, 3}, {0, 3, 4}}, // + {{0, 2, 5}, {0, 1, 3}, {0, 3, 4}}, // + {4, 2}, // + {3, 2})); } BOOST_AUTO_TEST_CASE(different_nested_ranks) {