@@ -1158,7 +1158,7 @@ Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op,
11581158 return result;
11591159}
11601160
1161- // / plan for Tensor contractions of fixed topology
1161+ // / plan for a binary Tensor contraction of fixed topology
11621162template <typename Annot, typename = std::enable_if_t <is_annotation_v<Annot>>>
11631163struct TensorContractionPlan {
11641164 using Indices = Einsum::index::Index<typename Annot::value_type>;
@@ -1186,6 +1186,12 @@ struct TensorContractionPlan {
11861186
11871187 const math::GemmHelper gemm_helper;
11881188
1189+ // / constructs plan for contraction C(aC) = A(aA) * B(aB). E.g.
1190+ // / `TensorContractionPlan("i,k", "k,j", "i,j")` constructs a plan
1191+ // / for matrix product.
1192+ // / \param aA einsum annotation for first argument (A)
1193+ // / \param aB einsum annotation for second argument (B)
1194+ // / \param aC einsum annotation for the result (C)
11891195 TensorContractionPlan (Annot const & aA, Annot const & aB, Annot const & aC)
11901196 : A(aA),
11911197 B (aB),
@@ -1207,12 +1213,10 @@ struct TensorContractionPlan {
12071213 }
12081214};
12091215
1210- // / contracts 2 tensors, with 1 plan construction per call. Thus this is
1211- // / inefficient; plan should be constructed separately and then used to for
1212- // / multiple calls (see the variant of this function that takes a plan as an
1213- // / argument)
1216+ // / contracts 2 tensors using the given contraction \p plan .
12141217// / @internal TODO constrain ResultTensorAllocator type so that non-sensical
12151218// / Allocators are prohibited
1219+ // / @return result of the contraction
12161220template <typename ResultTensorAllocator = void , typename TensorA,
12171221 typename TensorB, typename Annot,
12181222 typename = std::enable_if_t <is_tensor_v<TensorA, TensorB> &&
@@ -1300,6 +1304,12 @@ struct TensorHadamardPlan {
13001304
13011305 const bool no_perm, perm_to_c, perm_a, perm_b;
13021306
1307+ // / constructs plan for generalized hadamard product C(aC) = A(aA) * B(aB).
1308+ // / E.g. `TensorHadamardPlan("i,j", "i,j", "j,i")` constructs a plan
1309+ // / for product C(j,i) = A(i,j) B (i,j)
1310+ // / \param aA einsum annotation for first argument (A)
1311+ // / \param aB einsum annotation for second argument (B)
1312+ // / \param aC einsum annotation for the result (C)
13031313 TensorHadamardPlan (Annot const & aA, Annot const & aB, Annot const & aC)
13041314 : A(aA),
13051315 B (aB),
0 commit comments