Skip to content

Commit abb7886

Browse files
committed
More efficient contractions involving modes of the outer tensor in ToT-involving expressions.
1 parent 87664ae commit abb7886

File tree

4 files changed

+142
-96
lines changed

4 files changed

+142
-96
lines changed

src/TiledArray/einsum/tiledarray.h

Lines changed: 94 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,9 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
420420
using ResultTensor = typename ArrayC::value_type;
421421
using ResultShape = typename ArrayC::shape_type;
422422

423+
auto const& tnsrExprA = A;
424+
auto const& tnsrExprB = B;
425+
423426
auto a = std::get<0>(Einsum::idx(A));
424427
auto b = std::get<0>(Einsum::idx(B));
425428
Einsum::Index<std::string> c = std::get<0>(cs);
@@ -536,16 +539,10 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
536539
// the evaluation can be delegated to the expression layer
537540
// for distarrays of both nested and non-nested tensor tiles.
538541
// *) If no Hadamard indices are present (!h) the evaluation
539-
// can be delegated to the expression _only_ for distarrays with
540-
// non-nested tensor tiles.
541-
// This is because even if Hadamard indices are not present, a contracted
542-
// index might be present pertinent to the outer tensor in case of a
543-
// nested-tile distarray, which is especially handled within this
544-
// function because expression layer cannot handle that yet.
542+
// can be delegated to the expression layer.
545543
//
546-
if ((h && !(i || e)) // pure Hadamard
547-
|| (IsArrayToT<ArrayC> && !(i || h)) // ToT result from outer-product
548-
|| (IsArrayT<ArrayC> && !h)) // T from general product without Hadamard
544+
if ((h && !(i || e)) // pure Hadamard
545+
|| !h) // no Hadamard
549546
{
550547
ArrayC C;
551548
C(std::string(c) + inner.c) = A * B;
@@ -577,21 +574,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
577574
return C;
578575
}
579576

580-
//
581-
// when contraction happens in the outer tensor
582-
// need to evaluate specially..
583-
//
584-
if (IsArrayToT<ArrayC> && i.size() > 0) {
585-
auto annot_c = std::string(h + e + i) + inner.c;
586-
auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
587-
auto temp2 = reduce_modes(temp1, i.size());
588-
589-
auto annot_c_ = std::string(h + e) + inner.c;
590-
decltype(temp2) result;
591-
result(std::string(c) + inner.c) = temp2(annot_c_);
592-
return result;
593-
}
594-
595577
using ::Einsum::index::permutation;
596578
using TiledArray::Permutation;
597579

@@ -640,79 +622,103 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
640622

641623
using Index = Einsum::Index<size_t>;
642624

643-
if constexpr (AreArraySame<ArrayA, ArrayB> &&
644-
AreArraySame<ArrayB, ArrayC>) {
645-
if (!e) { // hadamard reduction
646-
auto &[A, B] = AB;
647-
TiledRange trange(range_map[i]);
648-
RangeProduct tiles;
649-
for (auto idx : i) {
650-
tiles *= Range(range_map[idx].tiles_range());
625+
if (!e) { // hadamard reduction
626+
auto &[A, B] = AB;
627+
TiledRange trange(range_map[i]);
628+
RangeProduct tiles;
629+
for (auto idx : i) {
630+
tiles *= Range(range_map[idx].tiles_range());
631+
}
632+
auto pa = A.permutation;
633+
auto pb = B.permutation;
634+
for (Index h : H.tiles) {
635+
if (!C.array.is_local(h)) continue;
636+
size_t batch = 1;
637+
for (size_t i = 0; i < h.size(); ++i) {
638+
batch *= H.batch[i].at(h[i]);
651639
}
652-
auto pa = A.permutation;
653-
auto pb = B.permutation;
654-
for (Index h : H.tiles) {
655-
if (!C.array.is_local(h)) continue;
656-
size_t batch = 1;
657-
for (size_t i = 0; i < h.size(); ++i) {
658-
batch *= H.batch[i].at(h[i]);
659-
}
660-
ResultTensor tile(TiledArray::Range{batch},
661-
typename ResultTensor::value_type{});
662-
for (Index i : tiles) {
663-
// skip this unless both input tiles exist
664-
const auto pahi_inv = apply_inverse(pa, h + i);
665-
const auto pbhi_inv = apply_inverse(pb, h + i);
666-
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv))
667-
continue;
668-
669-
auto ai = A.array.find(pahi_inv).get();
670-
auto bi = B.array.find(pbhi_inv).get();
671-
if (pa) ai = ai.permute(pa);
672-
if (pb) bi = bi.permute(pb);
673-
auto shape = trange.tile(i);
674-
ai = ai.reshape(shape, batch);
675-
bi = bi.reshape(shape, batch);
676-
for (size_t k = 0; k < batch; ++k) {
677-
using Ix = ::Einsum::Index<std::string>;
678-
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
679-
auto aik = ai.batch(k);
680-
auto bik = bi.batch(k);
681-
auto vol = aik.total_size();
682-
TA_ASSERT(vol == bik.total_size());
683-
684-
auto &el = tile({k});
685-
using TensorT = std::remove_reference_t<decltype(el)>;
686-
687-
auto mult_op = [&inner](auto const &l,
688-
auto const &r) -> TensorT {
689-
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
690-
inner.B, inner.C)
691-
: TA::detail::tensor_contract(
692-
l, inner.A, r, inner.B, inner.C);
693-
};
694-
695-
for (auto i = 0; i < vol; ++i)
696-
el.add_to(mult_op(aik.data()[i], bik.data()[i]));
697-
698-
} else {
699-
auto hk = ai.batch(k).dot(bi.batch(k));
700-
tile({k}) += hk;
701-
}
640+
ResultTensor tile(TiledArray::Range{batch},
641+
typename ResultTensor::value_type{});
642+
for (Index i : tiles) {
643+
// skip this unless both input tiles exist
644+
const auto pahi_inv = apply_inverse(pa, h + i);
645+
const auto pbhi_inv = apply_inverse(pb, h + i);
646+
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
647+
648+
auto ai = A.array.find(pahi_inv).get();
649+
auto bi = B.array.find(pbhi_inv).get();
650+
if (pa) ai = ai.permute(pa);
651+
if (pb) bi = bi.permute(pb);
652+
auto shape = trange.tile(i);
653+
ai = ai.reshape(shape, batch);
654+
bi = bi.reshape(shape, batch);
655+
for (size_t k = 0; k < batch; ++k) {
656+
using Ix = ::Einsum::Index<std::string>;
657+
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
658+
auto aik = ai.batch(k);
659+
auto bik = bi.batch(k);
660+
auto vol = aik.total_size();
661+
TA_ASSERT(vol == bik.total_size());
662+
663+
auto &el = tile({k});
664+
using TensorT = std::remove_reference_t<decltype(el)>;
665+
666+
auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
667+
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
668+
inner.B, inner.C)
669+
: TA::detail::tensor_contract(l, inner.A, r,
670+
inner.B, inner.C);
671+
};
672+
673+
for (auto i = 0; i < vol; ++i)
674+
el.add_to(mult_op(aik.data()[i], bik.data()[i]));
675+
676+
} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
677+
auto aik = ai.batch(k);
678+
auto bik = bi.batch(k);
679+
auto vol = aik.total_size();
680+
TA_ASSERT(vol == bik.total_size());
681+
682+
auto &el = tile({k});
683+
684+
for (auto i = 0; i < vol; ++i)
685+
if constexpr (IsArrayToT<ArrayA>) {
686+
el.add_to(aik.data()[i].scale(bik.data()[i]));
687+
} else {
688+
el.add_to(bik.data()[i].scale(aik.data()[i]));
689+
}
690+
691+
} else {
692+
auto hk = ai.batch(k).dot(bi.batch(k));
693+
tile({k}) += hk;
702694
}
703695
}
704-
auto pc = C.permutation;
705-
auto shape = apply_inverse(pc, C.array.trange().tile(h));
706-
tile = tile.reshape(shape);
707-
if (pc) tile = tile.permute(pc);
708-
C.array.set(h, tile);
709696
}
710-
return C.array;
697+
auto pc = C.permutation;
698+
auto shape = apply_inverse(pc, C.array.trange().tile(h));
699+
tile = tile.reshape(shape);
700+
if (pc) tile = tile.permute(pc);
701+
C.array.set(h, tile);
711702
}
703+
return C.array;
712704
}
713705

714706
// generalized contraction
715707

708+
if constexpr (IsArrayToT<ArrayC>) {
709+
if (inner.C != inner.h + inner.e) {
710+
// when inner tensor permutation is non-trivial (could be potentially
711+
// elided by extending this function (@c einsum) to take into account
712+
// of inner tensor's permutations)
713+
auto temp_annot = std::string(c) + ";" + std::string(inner.h + inner.e);
714+
ArrayC temp = einsum(tnsrExprA, tnsrExprB,
715+
Einsum::idx<ArrayC>(temp_annot), world);
716+
ArrayC result;
717+
result(std::string(c) + inner.c) = temp(temp_annot);
718+
return result;
719+
}
720+
}
721+
716722
auto update_tr = [&e = std::as_const(e), &i = std::as_const(i),
717723
&range_map = std::as_const(range_map)](auto &term) {
718724
auto ei = (e + i & term.idx);

src/TiledArray/expressions/cont_engine.h

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,25 +279,62 @@ class ContEngine : public BinaryEngine<Derived> {
279279
outer_size(left_indices_), outer_size(right_indices_),
280280
(!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}));
281281
} else {
282+
283+
auto make_total_perm = [this]() -> BipartitePermutation {
284+
if (this->product_type() != TensorProduct::Contraction
285+
|| this->implicit_permute_inner_)
286+
return this->implicit_permute_outer_
287+
? BipartitePermutation()
288+
: BipartitePermutation(outer(this->perm_));
289+
290+
// Here,
291+
// this->product_type() is Tensor::Contraction, and,
292+
// this->implicit_permute_inner_ is false
293+
294+
return this->inner_product_type() == TensorProduct::Scale
295+
? BipartitePermutation(outer(this->perm_))
296+
: this->perm_;
297+
};
298+
299+
auto total_perm = make_total_perm();
300+
282301
// factor_ is absorbed into inner_tile_nonreturn_op_
283302
op_ = op_type(
284303
left_op, right_op, scalar_type(1), outer_size(indices_),
285304
outer_size(left_indices_), outer_size(right_indices_),
286-
(!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}),
305+
total_perm,
287306
this->element_nonreturn_op_);
288307
}
289308
trange_ = ContEngine_::make_trange(outer_perm);
290309
shape_ = ContEngine_::make_shape(outer_perm);
291310
} else {
292311
// Initialize non-permuted structure
312+
293313
if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
294314
op_ = op_type(left_op, right_op, factor_, outer_size(indices_),
295315
outer_size(left_indices_), outer_size(right_indices_));
296316
} else {
317+
318+
auto make_total_perm = [this]() -> BipartitePermutation {
319+
if (this->product_type() != TensorProduct::Contraction
320+
|| this->implicit_permute_inner_)
321+
return {};
322+
323+
// Here,
324+
// this->product_type() is Tensor::Contraction, and,
325+
// this->implicit_permute_inner_ is false
326+
327+
return this->inner_product_type() == TensorProduct::Scale
328+
? BipartitePermutation(outer(this->perm_))
329+
: this->perm_;
330+
};
331+
332+
auto total_perm = make_total_perm();
333+
297334
// factor_ is absorbed into inner_tile_nonreturn_op_
298335
op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
299336
outer_size(left_indices_), outer_size(right_indices_),
300-
BipartitePermutation{}, this->element_nonreturn_op_);
337+
total_perm, this->element_nonreturn_op_);
301338
}
302339
trange_ = ContEngine_::make_trange();
303340
shape_ = ContEngine_::make_shape();
@@ -509,12 +546,15 @@ class ContEngine : public BinaryEngine<Derived> {
509546
inner_size(this->left_indices_),
510547
inner_size(this->right_indices_));
511548
this->element_nonreturn_op_ =
512-
[contrreduce_op](result_tile_element_type& result,
513-
const left_tile_element_type& left,
514-
const right_tile_element_type& right) {
549+
[contrreduce_op, permute_inner = this->product_type() !=
550+
TensorProduct::Contraction](
551+
result_tile_element_type& result,
552+
const left_tile_element_type& left,
553+
const right_tile_element_type& right) {
515554
contrreduce_op(result, left, right);
516-
if (!TA::empty(result))
517-
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
555+
// permutations of result are applied as "postprocessing"
556+
if (permute_inner && !TA::empty(result))
557+
result = contrreduce_op(result);
518558
};
519559
} // ToT x ToT
520560
} else if (inner_prod == TensorProduct::Hadamard) {

src/TiledArray/tensor/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,7 @@ class Tensor {
16301630
template <typename Right,
16311631
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
16321632
Tensor add(const Right& right) const& {
1633+
if (right.empty()) return *this;
16331634
return binary(
16341635
right,
16351636
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {

src/TiledArray/tile_op/contract_reduce.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,6 @@ class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {
332332

333333
if constexpr (!ContractReduceBase_::plain_tensors) {
334334
TA_ASSERT(this->elem_muladd_op());
335-
// not yet implemented
336335
gemm(result, left, right, ContractReduceBase_::gemm_helper(),
337336
this->elem_muladd_op());
338337
} else { // plain tensors

0 commit comments

Comments
 (0)