
Commit ec56aa0

Merge pull request #512 from ValeevGroup/evaleev/fix/einsum-hadamard-reduction
fix einsum hadamard reduction
2 parents (e2d2e7a + 8008ced), commit ec56aa0

File tree: 4 files changed, +250 -127 lines

src/TiledArray/conversions/make_array.h

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ inline Array make_array(World& world, const detail::trange_t<Array>& trange,
       op);
 }
 
-/// a make_array variant that uses a sequence of tiles
+/// a make_array variant that uses a sequence of {tile_index,tile} pairs
 /// to construct a DistArray with default pmap
 template <typename Array, typename Tiles>
 Array make_array(World& world, const detail::trange_t<Array>& tiled_range,
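
For context, a minimal usage sketch of the pair-sequence overload that the updated comment documents, mirroring the call added to src/TiledArray/einsum/tiledarray.h below; the function name build_from_local_tiles, the TA::TArrayD / TA::Tensor<double> types, and the Index alias are illustrative assumptions, not part of this commit:

#include <tiledarray.h>
#include <utility>
#include <vector>

// Illustrative only: build a DistArray from locally computed
// {tile_index, tile} pairs via the pair-sequence make_array overload.
TA::TArrayD build_from_local_tiles(TA::World& world,
                                   const TA::TiledRange& trange) {
  using Index = std::vector<std::size_t>;  // assumed index type
  std::vector<std::pair<Index, TA::Tensor<double>>> local_tiles;
  // ... fill local_tiles with the tiles this rank owns ...
  return TA::make_array<TA::TArrayD>(world, trange, local_tiles.begin(),
                                     local_tiles.end(),
                                     /* replicated = */ false);
}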

src/TiledArray/einsum/tiledarray.h

Lines changed: 66 additions & 32 deletions
@@ -600,6 +600,8 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
   std::invoke(update_perm_and_indices, std::get<0>(AB));
   std::invoke(update_perm_and_indices, std::get<1>(AB));
 
+  // construct result, with "dense" DistArray; the array will be
+  // reconstructed from local tiles later
   ArrayTerm<ArrayC> C = {ArrayC(world, TiledRange(range_map[c])), c};
   for (auto idx : e) {
     C.tiles *= Range(range_map[idx].tiles_range());
@@ -609,6 +611,16 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
   }
   C.expr = e;
 
+  using Index = Einsum::Index<size_t>;
+
+  // this will collect local tiles of C.array, to be used to rebuild C.array
+  std::vector<std::pair<Index, ResultTensor>> C_local_tiles;
+  auto build_C_array = [&]() {
+    C.array = make_array<ArrayC>(world, TiledRange(range_map[c]),
+                                 C_local_tiles.begin(), C_local_tiles.end(),
+                                 /* replicated = */ false);
+  };
+
   std::get<0>(AB).expr += inner.a;
   std::get<1>(AB).expr += inner.b;

@@ -627,19 +639,56 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
     }
   }
 
-  using Index = Einsum::Index<size_t>;
-
   if (!e) { // hadamard reduction
+
     auto &[A, B] = AB;
     TiledRange trange(range_map[i]);
     RangeProduct tiles;
     for (auto idx : i) {
       tiles *= Range(range_map[idx].tiles_range());
     }
+
+    // the inner product can be either hadamard or a contraction
+    using TensorT = typename decltype(A.array)::value_type::value_type;
+    static_assert(
+        std::is_same_v<TensorT,
+                       typename decltype(A.array)::value_type::value_type>);
+    constexpr bool is_tot = detail::is_tensor_v<TensorT>;
+    auto element_hadamard_op =
+        (is_tot && inner.h)
+            ? std::make_optional(
+                  [&inner, plan = detail::TensorHadamardPlan(inner.A, inner.B,
+                                                             inner.C)](
+                      auto const &l, auto const &r) -> TensorT {
+                    if (l.empty() || r.empty()) return TensorT{};
+                    return detail::tensor_hadamard(l, r, plan);
+                  })
+            : std::nullopt;
+    auto element_contract_op =
+        (is_tot && !inner.h)
+            ? std::make_optional(
+                  [&inner, plan = detail::TensorContractionPlan(
+                               inner.A, inner.B, inner.C)](
+                      auto const &l, auto const &r) -> TensorT {
+                    if (l.empty() || r.empty()) return TensorT{};
+                    return detail::tensor_contract(l, r, plan);
+                  })
+            : std::nullopt;
+    auto element_product_op = [&inner, &element_hadamard_op,
+                               &element_contract_op](
+                                  auto const &l, auto const &r) -> TensorT {
+      TA_ASSERT(inner.h ? element_hadamard_op.has_value()
+                        : element_contract_op.has_value());
+      return inner.h ? element_hadamard_op.value()(l, r)
+                     : element_contract_op.value()(l, r);
+    };
+
     auto pa = A.permutation;
     auto pb = B.permutation;
     for (Index h : H.tiles) {
-      if (!C.array.is_local(h)) continue;
+      auto const pc = C.permutation;
+      auto const c = apply(pc, h);
+      if (!C.array.is_local(c)) continue;
       size_t batch = 1;
       for (size_t i = 0; i < h.size(); ++i) {
         batch *= H.batch[i].at(h[i]);
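
The essential change in this hunk is that the hadamard tile index h is now mapped through C's permutation, c = apply(pc, h), before it is used to test locality (and, in the following hunks, to look up the tile extent and to store the tile); previously the unpermuted h addressed C.array directly. A toy, self-contained illustration of why that matters; the exact index convention is whatever the apply(pc, h) used here implements, so the permutation loop below is only schematic:

#include <array>
#include <cstddef>
#include <iostream>

int main() {
  // Toy example: C stores its modes in the reverse order of the hadamard index.
  std::array<std::size_t, 2> perm{1, 0};
  std::array<std::size_t, 2> h{2, 5};  // hadamard tile index

  std::array<std::size_t, 2> c{};      // c = apply(perm, h), schematic version
  for (std::size_t i = 0; i < h.size(); ++i) c[perm[i]] = h[i];

  std::cout << "h = (2,5) maps to c = (" << c[0] << ',' << c[1] << ")\n";
  // Querying is_local(h) would consult tile (2,5) of C, which is a different
  // tile (and possibly a different owner) than the tile that actually holds
  // this data, (5,2).
}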
@@ -670,16 +719,8 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
         auto &el = tile({k});
         using TensorT = std::remove_reference_t<decltype(el)>;
 
-        auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
-          if (l.empty() || r.empty()) return TensorT{};
-          return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
-                                                       inner.B, inner.C)
-                         : TA::detail::tensor_contract(l, inner.A, r,
-                                                       inner.B, inner.C);
-        };
-
         for (auto i = 0; i < vol; ++i)
-          el.add_to(mult_op(aik.data()[i], bik.data()[i]));
+          el.add_to(element_product_op(aik.data()[i], bik.data()[i]));
 
       } else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
         auto aik = ai.batch(k);
@@ -702,14 +743,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
           }
         }
       }
-      auto pc = C.permutation;
-      auto shape = apply_inverse(pc, C.array.trange().tile(h));
+      // data is stored as h1 h2 ... but all modes folded as 1 batch dim
+      // first reshape to h = (h1 h2 ...)
+      // n.b. can't just use shape = C.array.trange().tile(h)
+      auto shape = apply_inverse(pc, C.array.trange().tile(c));
       tile = tile.reshape(shape);
+      // then permute to target C layout c = (c1 c2 ...)
       if (pc) tile = tile.permute(pc);
-      C.array.set(h, tile);
+      // and move to C_local_tiles
+      C_local_tiles.emplace_back(std::move(c), std::move(tile));
     }
+
+    build_C_array();
+
     return C.array;
-  }
+  } // end: hadamard reduction
 
   // generalized contraction
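
The new comments make the tile layout explicit: in this branch the accumulated tile holds all hadamard modes folded into a single batch dimension, so it is first reshaped to the per-mode extents of h (hence apply_inverse(pc, ...) applied to the C tile range rather than using that range directly) and only then permuted into C's mode order. A toy sketch of that reshape-then-permute ordering on a flat buffer, with made-up extents; it is not TiledArray code:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // A flat batch of 2*3 values, interpreted with hadamard extents (2,3),
  // then permuted to the result layout (3,2).
  std::vector<double> flat{0, 1, 2, 3, 4, 5};  // batch-major storage
  constexpr std::size_t n0 = 2, n1 = 3;        // hadamard extents

  // "reshape": element (i,j) of the (2,3) view is flat[i*n1 + j]
  // "permute": element (j,i) of the (3,2) result equals element (i,j) of the view
  std::vector<double> permuted(n0 * n1);
  for (std::size_t i = 0; i < n0; ++i)
    for (std::size_t j = 0; j < n1; ++j)
      permuted[j * n0 + i] = flat[i * n1 + j];

  for (double v : permuted) std::cout << v << ' ';  // prints: 0 3 1 4 2 5
  std::cout << '\n';
}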

@@ -740,7 +788,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
   std::invoke(update_tr, std::get<1>(AB));
 
   std::vector<std::shared_ptr<World>> worlds;
-  std::vector<std::tuple<Index, ResultTensor>> local_tiles;
 
   // iterates over tiles of hadamard indices
   for (Index h : H.tiles) {
@@ -798,26 +845,13 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
         shape = apply_inverse(P, shape);
         tile = tile.reshape(shape);
         if (P) tile = tile.permute(P);
-        local_tiles.push_back({c, tile});
+        C_local_tiles.emplace_back(std::move(c), std::move(tile));
       }
       // mark for lazy deletion
       C.ei = ArrayC();
     }
 
-  if constexpr (!ResultShape::is_dense()) {
-    TiledRange tiled_range = TiledRange(range_map[c]);
-    std::vector<std::pair<Index, float>> tile_norms;
-    for (auto &[index, tile] : local_tiles) {
-      tile_norms.push_back({index, tile.norm()});
-    }
-    ResultShape shape(world, tile_norms, tiled_range);
-    C.array = ArrayC(world, TiledRange(range_map[c]), shape);
-  }
-
-  for (auto &[index, tile] : local_tiles) {
-    if (C.array.is_zero(index)) continue;
-    C.array.set(index, tile);
-  }
+  build_C_array();
 
   for (auto &w : worlds) {
     w->gop.fence();
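
With both the hadamard-reduction and generalized-contraction branches now routing their tiles through C_local_tiles and build_C_array, the explicit shape construction deleted above no longer lives in einsum; a natural reading of this diff (an inference, not something the diff states) is that make_array takes over the equivalent work. For reference, a sketch of that equivalent logic written against concrete sparse-policy types; TA::TSpArrayD, TA::SparseShape<float>, the Index alias, and assemble_sparse are illustrative assumptions rather than the library's actual internals:

#include <tiledarray.h>
#include <utility>
#include <vector>

// Illustrative sketch mirroring the deleted block: derive a sparse shape from
// the norms of the locally produced tiles, construct the array with that
// shape, then fill in the tiles the shape keeps.
using Index = std::vector<std::size_t>;  // assumed index type

TA::TSpArrayD assemble_sparse(
    TA::World& world, const TA::TiledRange& trange,
    const std::vector<std::pair<Index, TA::Tensor<double>>>& local_tiles) {
  std::vector<std::pair<Index, float>> tile_norms;
  for (const auto& [index, tile] : local_tiles)
    tile_norms.emplace_back(index, static_cast<float>(tile.norm()));
  TA::SparseShape<float> shape(world, tile_norms, trange);
  TA::TSpArrayD result(world, trange, shape);
  for (const auto& [index, tile] : local_tiles) {
    if (result.is_zero(index)) continue;  // screened out by the shape
    result.set(index, tile);
  }
  return result;
}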
