@@ -420,6 +420,9 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
   using ResultTensor = typename ArrayC::value_type;
   using ResultShape = typename ArrayC::shape_type;

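+  // Keep handles to the original tensor expressions: they are reused by the
+  // recursive einsum call further below when the requested inner-tensor index
+  // order differs from the one produced naturally by the product.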
+  auto const &tnsrExprA = A;
+  auto const &tnsrExprB = B;
+
   auto a = std::get<0>(Einsum::idx(A));
   auto b = std::get<0>(Einsum::idx(B));
   Einsum::Index<std::string> c = std::get<0>(cs);
@@ -536,16 +539,10 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
   // the evaluation can be delegated to the expression layer
   // for distarrays of both nested and non-nested tensor tiles.
   // *) If no Hadamard indices are present (!h) the evaluation
-  // can be delegated to the expression _only_ for distarrays with
-  // non-nested tensor tiles.
-  // This is because even if Hadamard indices are not present, a contracted
-  // index might be present pertinent to the outer tensor in case of a
-  // nested-tile distarray, which is especially handled within this
-  // function because expression layer cannot handle that yet.
+  // can be delegated to the expression layer.
   //
-  if ((h && !(i || e))                      // pure Hadamard
-      || (IsArrayToT<ArrayC> && !(i || h))  // ToT result from outer-product
-      || (IsArrayT<ArrayC> && !h))  // T from general product without Hadamard
+  if ((h && !(i || e))  // pure Hadamard
+      || !h)            // no Hadamard
   {
     ArrayC C;
     C(std::string(c) + inner.c) = A * B;
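The branch above keys off the outer-index classes computed earlier: h holds the Hadamard indices (shared by A, B, and the result), i the contracted indices (shared by A and B but absent from the result), and e the remaining, external result indices. The standalone sketch below is illustrative only; the classify helper is hypothetical and not part of TiledArray, and it uses single-character indices instead of TiledArray's comma-separated annotations.

// Illustrative sketch: a tiny standalone classifier mirroring the h/e/i
// index taxonomy used in the dispatch above (not TiledArray code).
#include <iostream>
#include <string>

int main() {
  auto contains = [](const std::string &s, char x) {
    return s.find(x) != std::string::npos;
  };
  auto classify = [&](const std::string &a, const std::string &b,
                      const std::string &c) {
    std::string h, e, i, seen;
    for (char x : a + b + c) {
      if (contains(seen, x)) continue;
      seen += x;
      const bool inA = contains(a, x), inB = contains(b, x),
                 inC = contains(c, x);
      if (inA && inB && inC)
        h += x;  // Hadamard: shared by both operands and the result
      else if (inA && inB)
        i += x;  // contracted: shared by the operands, absent from the result
      else if (inC)
        e += x;  // external: a result index coming from only one operand
    }
    std::cout << a << "," << b << "->" << c << "   h=\"" << h << "\" e=\"" << e
              << "\" i=\"" << i << "\"\n";
  };
  classify("ij", "ij", "ij");     // pure Hadamard (h="ij"): delegated above
  classify("ik", "kj", "ij");     // plain contraction, no Hadamard: delegated above
  classify("ijk", "ikl", "ijl");  // h="i", e="jl", i="k": handled further below
  return 0;
}

With the relaxed condition, the first two cases short-circuit to the expression-layer evaluation even for nested-tile (tensor-of-tensor) arrays, which is consistent with the removal of the special outer-contraction block in the next hunk.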
@@ -577,21 +574,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
     return C;
   }

-  //
-  // when contraction happens in the outer tensor
-  // need to evaluate specially..
-  //
-  if (IsArrayToT<ArrayC> && i.size() > 0) {
-    auto annot_c = std::string(h + e + i) + inner.c;
-    auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
-    auto temp2 = reduce_modes(temp1, i.size());
-
-    auto annot_c_ = std::string(h + e) + inner.c;
-    decltype(temp2) result;
-    result(std::string(c) + inner.c) = temp2(annot_c_);
-    return result;
-  }
-
   using ::Einsum::index::permutation;
   using TiledArray::Permutation;

@@ -640,79 +622,103 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,

   using Index = Einsum::Index<size_t>;

-  if constexpr (AreArraySame<ArrayA, ArrayB> &&
-                AreArraySame<ArrayB, ArrayC>) {
-    if (!e) {  // hadamard reduction
-      auto &[A, B] = AB;
-      TiledRange trange(range_map[i]);
-      RangeProduct tiles;
-      for (auto idx : i) {
-        tiles *= Range(range_map[idx].tiles_range());
+  if (!e) {  // hadamard reduction
+    auto &[A, B] = AB;
+    TiledRange trange(range_map[i]);
+    RangeProduct tiles;
+    for (auto idx : i) {
+      tiles *= Range(range_map[idx].tiles_range());
+    }
+    auto pa = A.permutation;
+    auto pb = B.permutation;
+    for (Index h : H.tiles) {
+      if (!C.array.is_local(h)) continue;
+      size_t batch = 1;
+      for (size_t i = 0; i < h.size(); ++i) {
+        batch *= H.batch[i].at(h[i]);
       }
-      auto pa = A.permutation;
-      auto pb = B.permutation;
-      for (Index h : H.tiles) {
-        if (!C.array.is_local(h)) continue;
-        size_t batch = 1;
-        for (size_t i = 0; i < h.size(); ++i) {
-          batch *= H.batch[i].at(h[i]);
-        }
-        ResultTensor tile(TiledArray::Range{batch},
-                          typename ResultTensor::value_type{});
-        for (Index i : tiles) {
-          // skip this unless both input tiles exist
-          const auto pahi_inv = apply_inverse(pa, h + i);
-          const auto pbhi_inv = apply_inverse(pb, h + i);
-          if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv))
-            continue;
-
-          auto ai = A.array.find(pahi_inv).get();
-          auto bi = B.array.find(pbhi_inv).get();
-          if (pa) ai = ai.permute(pa);
-          if (pb) bi = bi.permute(pb);
-          auto shape = trange.tile(i);
-          ai = ai.reshape(shape, batch);
-          bi = bi.reshape(shape, batch);
-          for (size_t k = 0; k < batch; ++k) {
-            using Ix = ::Einsum::Index<std::string>;
-            if constexpr (AreArrayToT<ArrayA, ArrayB>) {
-              auto aik = ai.batch(k);
-              auto bik = bi.batch(k);
-              auto vol = aik.total_size();
-              TA_ASSERT(vol == bik.total_size());
-
-              auto &el = tile({k});
-              using TensorT = std::remove_reference_t<decltype(el)>;
-
-              auto mult_op = [&inner](auto const &l,
-                                      auto const &r) -> TensorT {
-                return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
-                                                             inner.B, inner.C)
-                               : TA::detail::tensor_contract(
-                                     l, inner.A, r, inner.B, inner.C);
-              };
-
-              for (auto i = 0; i < vol; ++i)
-                el.add_to(mult_op(aik.data()[i], bik.data()[i]));
-
-            } else {
-              auto hk = ai.batch(k).dot(bi.batch(k));
-              tile({k}) += hk;
-            }
+      ResultTensor tile(TiledArray::Range{batch},
+                        typename ResultTensor::value_type{});
+      for (Index i : tiles) {
+        // skip this unless both input tiles exist
+        const auto pahi_inv = apply_inverse(pa, h + i);
+        const auto pbhi_inv = apply_inverse(pb, h + i);
+        if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+
+        auto ai = A.array.find(pahi_inv).get();
+        auto bi = B.array.find(pbhi_inv).get();
+        if (pa) ai = ai.permute(pa);
+        if (pb) bi = bi.permute(pb);
+        auto shape = trange.tile(i);
+        ai = ai.reshape(shape, batch);
+        bi = bi.reshape(shape, batch);
+        for (size_t k = 0; k < batch; ++k) {
+          using Ix = ::Einsum::Index<std::string>;
+          if constexpr (AreArrayToT<ArrayA, ArrayB>) {
+            auto aik = ai.batch(k);
+            auto bik = bi.batch(k);
+            auto vol = aik.total_size();
+            TA_ASSERT(vol == bik.total_size());
+
+            auto &el = tile({k});
+            using TensorT = std::remove_reference_t<decltype(el)>;
+
+            auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
+              return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
+                                                           inner.B, inner.C)
+                             : TA::detail::tensor_contract(l, inner.A, r,
+                                                           inner.B, inner.C);
+            };

+            for (auto i = 0; i < vol; ++i)
+              el.add_to(mult_op(aik.data()[i], bik.data()[i]));
+
+          } else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
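+            // Mixed plain/nested operands: one side holds scalar tile
+            // elements, the other nested tensors; accumulate the nested
+            // tensor scaled by the matching scalar element.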
+            auto aik = ai.batch(k);
+            auto bik = bi.batch(k);
+            auto vol = aik.total_size();
+            TA_ASSERT(vol == bik.total_size());
+
+            auto &el = tile({k});
+
+            for (auto i = 0; i < vol; ++i)
+              if constexpr (IsArrayToT<ArrayA>) {
+                el.add_to(aik.data()[i].scale(bik.data()[i]));
+              } else {
+                el.add_to(bik.data()[i].scale(aik.data()[i]));
+              }
+
+          } else {
+            auto hk = ai.batch(k).dot(bi.batch(k));
+            tile({k}) += hk;
           }
         }
-        auto pc = C.permutation;
-        auto shape = apply_inverse(pc, C.array.trange().tile(h));
-        tile = tile.reshape(shape);
-        if (pc) tile = tile.permute(pc);
-        C.array.set(h, tile);
       }
-      return C.array;
+      auto pc = C.permutation;
+      auto shape = apply_inverse(pc, C.array.trange().tile(h));
+      tile = tile.reshape(shape);
+      if (pc) tile = tile.permute(pc);
+      C.array.set(h, tile);
     }
+    return C.array;
   }
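For orientation, the hadamard-reduction branch above computes, per local result tile, a batched reduction: the Hadamard modes are flattened into a batch dimension and the contracted modes are reduced within each batch slot (for plain tiles via ai.batch(k).dot(bi.batch(k)), for nested tiles via the inner hadamard/contract op). A standalone toy sketch of the plain-tile case, with no TiledArray dependence and made-up data:

// Toy model of the batched reduction c(i) = sum_j a(i,j) * b(i,j):
// i plays the role of the Hadamard (batch) modes, j of the contracted modes.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t I = 2, J = 3;                    // extents of i and j
  const std::vector<double> a = {1, 2, 3, 4, 5, 6};  // a(i,j), row-major
  const std::vector<double> b = {1, 1, 1, 2, 2, 2};  // b(i,j), row-major
  std::vector<double> c(I, 0.0);
  for (std::size_t k = 0; k < I; ++k)    // loop over the batch (Hadamard) index
    for (std::size_t j = 0; j < J; ++j)  // dot product over the contracted index
      c[k] += a[k * J + j] * b[k * J + j];
  std::cout << c[0] << " " << c[1] << "\n";  // prints "6 30"
  return 0;
}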

   // generalized contraction

+  if constexpr (IsArrayToT<ArrayC>) {
+    if (inner.C != inner.h + inner.e) {
+      // The requested inner-tensor index order differs from the naturally
+      // produced one. (This extra pass could potentially be elided by
+      // extending this function (@c einsum) to account for inner-tensor
+      // permutations directly.)
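+      // e.g. if C is annotated "i,j;n,m" while the product naturally yields
+      // the inner order "m,n", first evaluate temp with annotation "i,j;m,n"
+      // and let the annotated assignment below permute the inner tensors.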
+      auto temp_annot = std::string(c) + ";" + std::string(inner.h + inner.e);
+      ArrayC temp = einsum(tnsrExprA, tnsrExprB,
+                           Einsum::idx<ArrayC>(temp_annot), world);
+      ArrayC result;
+      result(std::string(c) + inner.c) = temp(temp_annot);
+      return result;
+    }
+  }
+
   auto update_tr = [&e = std::as_const(e), &i = std::as_const(i),
                     &range_map = std::as_const(range_map)](auto &term) {
     auto ei = (e + i & term.idx);