@@ -600,6 +600,8 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
600600 std::invoke (update_perm_and_indices, std::get<0 >(AB));
601601 std::invoke (update_perm_and_indices, std::get<1 >(AB));
602602
603+ // construct result, with "dense" DistArray; the array will be
604+ // reconstructed from local tiles later
603605 ArrayTerm<ArrayC> C = {ArrayC (world, TiledRange (range_map[c])), c};
604606 for (auto idx : e) {
605607 C.tiles *= Range (range_map[idx].tiles_range ());
@@ -609,6 +611,16 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
609611 }
610612 C.expr = e;
611613
614+ using Index = Einsum::Index<size_t >;
615+
616+ // this will collect local tiles of C.array, to be used to rebuild C.array
617+ std::vector<std::pair<Index, ResultTensor>> C_local_tiles;
618+ auto build_C_array = [&]() {
619+ C.array = make_array<ArrayC>(world, TiledRange (range_map[c]),
620+ C_local_tiles.begin (), C_local_tiles.end (),
621+ /* replicated = */ false );
622+ };
623+
612624 std::get<0 >(AB).expr += inner.a ;
613625 std::get<1 >(AB).expr += inner.b ;
614626
@@ -627,19 +639,56 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
627639 }
628640 }
629641
630- using Index = Einsum::Index<size_t >;
631-
632642 if (!e) { // hadamard reduction
643+
633644 auto &[A, B] = AB;
634645 TiledRange trange (range_map[i]);
635646 RangeProduct tiles;
636647 for (auto idx : i) {
637648 tiles *= Range (range_map[idx].tiles_range ());
638649 }
650+
651+ // the inner product can be either hadamard or a contraction
652+ using TensorT = typename decltype (A.array )::value_type::value_type;
653+ static_assert (
654+ std::is_same_v<TensorT,
655+ typename decltype (A.array )::value_type::value_type>);
656+ constexpr bool is_tot = detail::is_tensor_v<TensorT>;
657+ auto element_hadamard_op =
658+ (is_tot && inner.h )
659+ ? std::make_optional (
660+ [&inner, plan = detail::TensorHadamardPlan (inner.A , inner.B ,
661+ inner.C )](
662+ auto const &l, auto const &r) -> TensorT {
663+ if (l.empty () || r.empty ()) return TensorT{};
664+ return detail::tensor_hadamard (l, r, plan);
665+ })
666+ : std::nullopt ;
667+ auto element_contract_op =
668+ (is_tot && !inner.h )
669+ ? std::make_optional (
670+ [&inner, plan = detail::TensorContractionPlan (
671+ inner.A , inner.B , inner.C )](
672+ auto const &l, auto const &r) -> TensorT {
673+ if (l.empty () || r.empty ()) return TensorT{};
674+ return detail::tensor_contract (l, r, plan);
675+ })
676+ : std::nullopt ;
677+ auto element_product_op = [&inner, &element_hadamard_op,
678+ &element_contract_op](
679+ auto const &l, auto const &r) -> TensorT {
680+ TA_ASSERT (inner.h ? element_hadamard_op.has_value ()
681+ : element_contract_op.has_value ());
682+ return inner.h ? element_hadamard_op.value ()(l, r)
683+ : element_contract_op.value ()(l, r);
684+ };
685+
639686 auto pa = A.permutation ;
640687 auto pb = B.permutation ;
641688 for (Index h : H.tiles ) {
642- if (!C.array .is_local (h)) continue ;
689+ auto const pc = C.permutation ;
690+ auto const c = apply (pc, h);
691+ if (!C.array .is_local (c)) continue ;
643692 size_t batch = 1 ;
644693 for (size_t i = 0 ; i < h.size (); ++i) {
645694 batch *= H.batch [i].at (h[i]);
@@ -670,16 +719,8 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
670719 auto &el = tile ({k});
671720 using TensorT = std::remove_reference_t <decltype (el)>;
672721
673- auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
674- if (l.empty () || r.empty ()) return TensorT{};
675- return inner.h ? TA::detail::tensor_hadamard (l, inner.A , r,
676- inner.B , inner.C )
677- : TA::detail::tensor_contract (l, inner.A , r,
678- inner.B , inner.C );
679- };
680-
681722 for (auto i = 0 ; i < vol; ++i)
682- el.add_to (mult_op (aik.data ()[i], bik.data ()[i]));
723+ el.add_to (element_product_op (aik.data ()[i], bik.data ()[i]));
683724
684725 } else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
685726 auto aik = ai.batch (k);
@@ -702,14 +743,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
702743 }
703744 }
704745 }
705- auto pc = C.permutation ;
706- auto shape = apply_inverse (pc, C.array .trange ().tile (h));
746+ // data is stored as h1 h2 ... but all modes folded as 1 batch dim
747+ // first reshape to h = (h1 h2 ...)
748+ // n.b. can't just use shape = C.array.trange().tile(h)
749+ auto shape = apply_inverse (pc, C.array .trange ().tile (c));
707750 tile = tile.reshape (shape);
751+ // then permute to target C layout c = (c1 c2 ...)
708752 if (pc) tile = tile.permute (pc);
709- C.array .set (h, tile);
753+ // and move to C_local_tiles
754+ C_local_tiles.emplace_back (std::move (c), std::move (tile));
710755 }
756+
757+ build_C_array ();
758+
711759 return C.array ;
712- }
760+ } // end: hadamard reduction
713761
714762 // generalized contraction
715763
@@ -740,7 +788,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
740788 std::invoke (update_tr, std::get<1 >(AB));
741789
742790 std::vector<std::shared_ptr<World>> worlds;
743- std::vector<std::tuple<Index, ResultTensor>> local_tiles;
744791
745792 // iterates over tiles of hadamard indices
746793 for (Index h : H.tiles ) {
@@ -798,26 +845,13 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
798845 shape = apply_inverse (P, shape);
799846 tile = tile.reshape (shape);
800847 if (P) tile = tile.permute (P);
801- local_tiles. push_back ({c, tile} );
848+ C_local_tiles. emplace_back ( std::move (c), std::move ( tile) );
802849 }
803850 // mark for lazy deletion
804851 C.ei = ArrayC ();
805852 }
806853
807- if constexpr (!ResultShape::is_dense ()) {
808- TiledRange tiled_range = TiledRange (range_map[c]);
809- std::vector<std::pair<Index, float >> tile_norms;
810- for (auto &[index, tile] : local_tiles) {
811- tile_norms.push_back ({index, tile.norm ()});
812- }
813- ResultShape shape (world, tile_norms, tiled_range);
814- C.array = ArrayC (world, TiledRange (range_map[c]), shape);
815- }
816-
817- for (auto &[index, tile] : local_tiles) {
818- if (C.array .is_zero (index)) continue ;
819- C.array .set (index, tile);
820- }
854+ build_C_array ();
821855
822856 for (auto &w : worlds) {
823857 w->gop .fence ();
0 commit comments