Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tensorwrapper/detail_/dsl_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,11 @@ class DSLBase {
/// Checks that @p output is a subset of @p input
void assert_is_subset_(const label_type& output,
const label_type& input) const {
if(output.intersection(input).size() < output.unique_index_size())
// Subset would have equality
if(output.intersection(input).size() < output.unique_index_size()) {
throw std::runtime_error(
"Output indices must be a subset of input indices");
}
}

/// Asserts that @p lhs is a permutation of @p rhs
Expand Down
7 changes: 6 additions & 1 deletion include/tensorwrapper/dsl/pairwise_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ class PairwiseParser {
*/
template<typename LHSType, typename RHSType>
void dispatch(LHSType&& lhs, const RHSType& rhs) {
lhs.object().permute_assignment(lhs.labels(), rhs);
if(lhs.labels().is_permutation(rhs.labels()))
lhs.object().permute_assignment(lhs.labels(), rhs);
else { // User just wants us to assign RHS to LHS
lhs.labels() = rhs.labels();
lhs.object().permute_assignment(rhs.labels(), rhs);
}
}

/** @brief Handles adding two expressions together.
Expand Down
14 changes: 7 additions & 7 deletions tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
auto scalar_layout = scalar_physical();
auto vector_layout = vector_physical(2);
auto matrix_layout = matrix_physical(2, 3);
auto tensor_layout = tensor_physical(1, 2, 3);
auto tensor_layout = tensor3_physical(1, 2, 3);

scalar_buffer scalar(eigen_scalar, scalar_layout);
vector_buffer vector(eigen_vector, vector_layout);
Expand Down Expand Up @@ -276,7 +276,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
auto tensor2 = testing::eigen_tensor3<TestType>();

std::array<int, 3> p102{1, 0, 2};
auto l102 = testing::tensor_physical(2, 1, 3);
auto l102 = testing::tensor3_physical(2, 1, 3);
tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102);

auto tijk = tensor("i,j,k");
Expand All @@ -285,7 +285,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
tensor2.addition_assignment("k,j,i", tijk, tjik);

std::array<int, 3> p210{2, 1, 0};
auto l210 = testing::tensor_physical(3, 2, 1);
auto l210 = testing::tensor3_physical(3, 2, 1);
tensor_buffer corr(eigen_tensor.shuffle(p210), l210);
corr.value()(0, 0, 0) = 20.0;
corr.value()(0, 1, 0) = 80.0;
Expand Down Expand Up @@ -392,7 +392,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
auto tensor2 = testing::eigen_tensor3<TestType>();

std::array<int, 3> p102{1, 0, 2};
auto l102 = testing::tensor_physical(2, 1, 3);
auto l102 = testing::tensor3_physical(2, 1, 3);
tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102);

auto tijk = tensor("i,j,k");
Expand All @@ -401,7 +401,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
tensor2.subtraction_assignment("k,j,i", tijk, tjik);

std::array<int, 3> p210{2, 1, 0};
auto l210 = testing::tensor_physical(3, 2, 1);
auto l210 = testing::tensor3_physical(3, 2, 1);
tensor_buffer corr(eigen_tensor.shuffle(p210), l210);
corr.value()(0, 0, 0) = 0.0;
corr.value()(0, 1, 0) = 0.0;
Expand Down Expand Up @@ -631,7 +631,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
auto tensor2 = testing::eigen_tensor3<TestType>();

std::array<int, 3> p102{1, 0, 2};
auto l102 = testing::tensor_physical(2, 1, 3);
auto l102 = testing::tensor3_physical(2, 1, 3);
tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102);

auto tijk = tensor("i,j,k");
Expand All @@ -640,7 +640,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) {
tensor2.multiplication_assignment("k,j,i", tijk, tjik);

std::array<int, 3> p210{2, 1, 0};
auto l210 = testing::tensor_physical(3, 2, 1);
auto l210 = testing::tensor3_physical(3, 2, 1);
tensor_buffer corr(eigen_tensor.shuffle(p210), l210);
corr.value()(0, 0, 0) = 100.0;
corr.value()(0, 1, 0) = 1600.0;
Expand Down
13 changes: 10 additions & 3 deletions tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ using namespace tensorwrapper;
TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) {
using object_type = TestType;

auto scalar_values = testing::scalar_values();
auto vector_values = testing::vector_values();
auto matrix_values = testing::matrix_values();
auto scalar_values = testing::scalar_values();
auto vector_values = testing::vector_values();
auto matrix_values = testing::matrix_values();
auto tensor4_values = testing::tensor4_values();

auto value0 = std::get<object_type>(scalar_values);
auto value1 = std::get<object_type>(vector_values);
auto value2 = std::get<object_type>(matrix_values);
auto value4 = std::get<object_type>(tensor4_values);

SECTION("assignment") {
value0("i,j") = value2("i,j");
Expand Down Expand Up @@ -61,6 +63,11 @@ TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) {

value1.multiplication_assignment("i,j", value2("i,j"), value2("i,j"));
REQUIRE(value1.are_equal(value0));

value0("m,n") = value2("l,s") * value4("m,n,s,l");
value1.multiplication_assignment("m,n", value2("l,s"),
value4("m,n,s,l"));
REQUIRE(value1.are_equal(value0));
}

SECTION("scalar_multiplication") {
Expand Down
9 changes: 9 additions & 0 deletions tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ TEST_CASE("DummyIndices<std::string>") {
REQUIRE(matrix.concatenation(vector) == dummy_indices_type("i,j,i"));
REQUIRE(matrix.concatenation(matrix) == dummy_indices_type("i,j,i,j"));
REQUIRE(matrix.concatenation(matrix2) == dummy_indices_type("i,j,k,l"));

auto x = matrix.concatenation(dummy_indices_type("i,j,l,s"));
REQUIRE(x == dummy_indices_type("i,j,i,j,l,s"));
}

SECTION("intersection") {
Expand All @@ -298,6 +301,12 @@ TEST_CASE("DummyIndices<std::string>") {
REQUIRE(matrix.intersection(vector) == dummy_indices_type("i"));
REQUIRE(matrix.intersection(matrix) == dummy_indices_type("i,j"));
REQUIRE(matrix.intersection(matrix2) == dummy_indices_type(""));

auto x = matrix.intersection(dummy_indices_type("i,j,l,s"));
REQUIRE(x == dummy_indices_type("i,j"));

auto y = matrix.intersection(dummy_indices_type("i,j,i,j,l,s"));
REQUIRE(x == dummy_indices_type("i,j"));
}

SECTION("difference") {
Expand Down
21 changes: 21 additions & 0 deletions tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,25 @@ inline auto matrix_values() {
Tensor{{1.0, 2.0}, {3.0, 4.0}}};
}

inline auto tensor3_values() {
return dsl_types{
smooth_tensor3(),
tensorwrapper::symmetry::Group(3),
tensorwrapper::sparsity::Pattern(3),
tensor3_logical(),
tensor3_physical(),
Tensor{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}};
}

inline auto tensor4_values() {
return dsl_types{
smooth_tensor4(),
tensorwrapper::symmetry::Group(4),
tensorwrapper::sparsity::Pattern(4),
tensor4_logical(),
tensor4_physical(),
Tensor{{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}},
{{{9.0, 10.0}, {11.0, 12.0}}, {{13.0, 14.0}, {15.0, 16.0}}}}};
}

} // namespace tensorwrapper::testing
22 changes: 16 additions & 6 deletions tests/cxx/unit_tests/tensorwrapper/testing/layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,14 @@ inline auto matrix_logical(std::size_t i = 10, std::size_t j = 10) {
return tensorwrapper::layout::Logical(smooth_matrix(i, j));
}

inline auto tensor_logical(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10) {
return tensorwrapper::layout::Logical(smooth_tensor(i, j, k));
inline auto tensor3_logical(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10) {
return tensorwrapper::layout::Logical(smooth_tensor3(i, j, k));
}

inline auto tensor4_logical(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10, std::size_t l = 10) {
return tensorwrapper::layout::Logical(smooth_tensor4(i, j, k, l));
}

// -----------------------------------------------------------------------------
Expand All @@ -57,9 +62,14 @@ inline auto matrix_physical(std::size_t i = 10, std::size_t j = 10) {
return tensorwrapper::layout::Physical(smooth_matrix(i, j));
}

inline auto tensor_physical(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10) {
return tensorwrapper::layout::Physical(smooth_tensor(i, j, k));
inline auto tensor3_physical(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10) {
return tensorwrapper::layout::Physical(smooth_tensor3(i, j, k));
}

inline auto tensor4_physical(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10, std::size_t l = 10) {
return tensorwrapper::layout::Physical(smooth_tensor4(i, j, k, l));
}

} // namespace tensorwrapper::testing
9 changes: 7 additions & 2 deletions tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ inline auto smooth_matrix(std::size_t i = 10, std::size_t j = 10) {
return tensorwrapper::shape::Smooth{i, j};
}

inline auto smooth_tensor(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10) {
inline auto smooth_tensor3(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10) {
return tensorwrapper::shape::Smooth{i, j, k};
}

inline auto smooth_tensor4(std::size_t i = 10, std::size_t j = 10,
std::size_t k = 10, std::size_t l = 10) {
return tensorwrapper::shape::Smooth{i, j, k, l};
}

} // namespace tensorwrapper::testing