diff --git a/include/tensorwrapper/detail_/dsl_base.hpp b/include/tensorwrapper/detail_/dsl_base.hpp index fe044881..94b9cf4e 100644 --- a/include/tensorwrapper/detail_/dsl_base.hpp +++ b/include/tensorwrapper/detail_/dsl_base.hpp @@ -274,9 +274,11 @@ class DSLBase { /// Checks that @p output is a subset of @p input void assert_is_subset_(const label_type& output, const label_type& input) const { - if(output.intersection(input).size() < output.unique_index_size()) + // Subset would have equality + if(output.intersection(input).size() < output.unique_index_size()) { throw std::runtime_error( "Output indices must be a subset of input indices"); + } } /// Asserts that @p lhs is a permutation of @p rhs diff --git a/include/tensorwrapper/dsl/pairwise_parser.hpp b/include/tensorwrapper/dsl/pairwise_parser.hpp index e88a700d..a7caf4d0 100644 --- a/include/tensorwrapper/dsl/pairwise_parser.hpp +++ b/include/tensorwrapper/dsl/pairwise_parser.hpp @@ -59,7 +59,12 @@ class PairwiseParser { */ template void dispatch(LHSType&& lhs, const RHSType& rhs) { - lhs.object().permute_assignment(lhs.labels(), rhs); + if(lhs.labels().is_permutation(rhs.labels())) + lhs.object().permute_assignment(lhs.labels(), rhs); + else { // User just wants us to assign RHS to LHS + lhs.labels() = rhs.labels(); + lhs.object().permute_assignment(rhs.labels(), rhs); + } } /** @brief Handles adding two expressions together. diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp index f3cc81db..d11a2943 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp @@ -79,7 +79,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { auto scalar_layout = scalar_physical(); auto vector_layout = vector_physical(2); auto matrix_layout = matrix_physical(2, 3); - auto tensor_layout = tensor_physical(1, 2, 3); + auto tensor_layout = tensor3_physical(1, 2, 3); scalar_buffer scalar(eigen_scalar, scalar_layout); vector_buffer vector(eigen_vector, vector_layout); @@ -276,7 +276,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { auto tensor2 = testing::eigen_tensor3(); std::array p102{1, 0, 2}; - auto l102 = testing::tensor_physical(2, 1, 3); + auto l102 = testing::tensor3_physical(2, 1, 3); tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102); auto tijk = tensor("i,j,k"); @@ -285,7 +285,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { tensor2.addition_assignment("k,j,i", tijk, tjik); std::array p210{2, 1, 0}; - auto l210 = testing::tensor_physical(3, 2, 1); + auto l210 = testing::tensor3_physical(3, 2, 1); tensor_buffer corr(eigen_tensor.shuffle(p210), l210); corr.value()(0, 0, 0) = 20.0; corr.value()(0, 1, 0) = 80.0; @@ -392,7 +392,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { auto tensor2 = testing::eigen_tensor3(); std::array p102{1, 0, 2}; - auto l102 = testing::tensor_physical(2, 1, 3); + auto l102 = testing::tensor3_physical(2, 1, 3); tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102); auto tijk = tensor("i,j,k"); @@ -401,7 +401,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { tensor2.subtraction_assignment("k,j,i", tijk, tjik); std::array p210{2, 1, 0}; - auto l210 = testing::tensor_physical(3, 2, 1); + auto l210 = testing::tensor3_physical(3, 2, 1); tensor_buffer corr(eigen_tensor.shuffle(p210), l210); corr.value()(0, 0, 0) = 0.0; corr.value()(0, 1, 0) = 0.0; @@ -631,7 +631,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { auto tensor2 = testing::eigen_tensor3(); std::array p102{1, 0, 2}; - auto l102 = testing::tensor_physical(2, 1, 3); + auto l102 = testing::tensor3_physical(2, 1, 3); tensor_buffer tensor102(eigen_tensor.shuffle(p102), l102); auto tijk = tensor("i,j,k"); @@ -640,7 +640,7 @@ TEMPLATE_LIST_TEST_CASE("Eigen", "", types2test) { tensor2.multiplication_assignment("k,j,i", tijk, tjik); std::array p210{2, 1, 0}; - auto l210 = testing::tensor_physical(3, 2, 1); + auto l210 = testing::tensor3_physical(3, 2, 1); tensor_buffer corr(eigen_tensor.shuffle(p210), l210); corr.value()(0, 0, 0) = 100.0; corr.value()(0, 1, 0) = 1600.0; diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp index 3565b558..b70d2b4f 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp @@ -22,13 +22,15 @@ using namespace tensorwrapper; TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) { using object_type = TestType; - auto scalar_values = testing::scalar_values(); - auto vector_values = testing::vector_values(); - auto matrix_values = testing::matrix_values(); + auto scalar_values = testing::scalar_values(); + auto vector_values = testing::vector_values(); + auto matrix_values = testing::matrix_values(); + auto tensor4_values = testing::tensor4_values(); auto value0 = std::get(scalar_values); auto value1 = std::get(vector_values); auto value2 = std::get(matrix_values); + auto value4 = std::get(tensor4_values); SECTION("assignment") { value0("i,j") = value2("i,j"); @@ -61,6 +63,11 @@ TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) { value1.multiplication_assignment("i,j", value2("i,j"), value2("i,j")); REQUIRE(value1.are_equal(value0)); + + value0("m,n") = value2("l,s") * value4("m,n,s,l"); + value1.multiplication_assignment("m,n", value2("l,s"), + value4("m,n,s,l")); + REQUIRE(value1.are_equal(value0)); } SECTION("scalar_multiplication") { diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp index 078a54fa..65cdbde0 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/dummy_indices.cpp @@ -280,6 +280,9 @@ TEST_CASE("DummyIndices") { REQUIRE(matrix.concatenation(vector) == dummy_indices_type("i,j,i")); REQUIRE(matrix.concatenation(matrix) == dummy_indices_type("i,j,i,j")); REQUIRE(matrix.concatenation(matrix2) == dummy_indices_type("i,j,k,l")); + + auto x = matrix.concatenation(dummy_indices_type("i,j,l,s")); + REQUIRE(x == dummy_indices_type("i,j,i,j,l,s")); } SECTION("intersection") { @@ -298,6 +301,12 @@ TEST_CASE("DummyIndices") { REQUIRE(matrix.intersection(vector) == dummy_indices_type("i")); REQUIRE(matrix.intersection(matrix) == dummy_indices_type("i,j")); REQUIRE(matrix.intersection(matrix2) == dummy_indices_type("")); + + auto x = matrix.intersection(dummy_indices_type("i,j,l,s")); + REQUIRE(x == dummy_indices_type("i,j")); + + auto y = matrix.intersection(dummy_indices_type("i,j,i,j,l,s")); + REQUIRE(x == dummy_indices_type("i,j")); } SECTION("difference") { diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp index 54e4f33f..14712c82 100644 --- a/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp +++ b/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp @@ -54,4 +54,25 @@ inline auto matrix_values() { Tensor{{1.0, 2.0}, {3.0, 4.0}}}; } +inline auto tensor3_values() { + return dsl_types{ + smooth_tensor3(), + tensorwrapper::symmetry::Group(3), + tensorwrapper::sparsity::Pattern(3), + tensor3_logical(), + tensor3_physical(), + Tensor{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}}; +} + +inline auto tensor4_values() { + return dsl_types{ + smooth_tensor4(), + tensorwrapper::symmetry::Group(4), + tensorwrapper::sparsity::Pattern(4), + tensor4_logical(), + tensor4_physical(), + Tensor{{{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, + {{{9.0, 10.0}, {11.0, 12.0}}, {{13.0, 14.0}, {15.0, 16.0}}}}}; +} + } // namespace tensorwrapper::testing \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/layouts.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/layouts.hpp index f1f71247..3da5f104 100644 --- a/tests/cxx/unit_tests/tensorwrapper/testing/layouts.hpp +++ b/tests/cxx/unit_tests/tensorwrapper/testing/layouts.hpp @@ -36,9 +36,14 @@ inline auto matrix_logical(std::size_t i = 10, std::size_t j = 10) { return tensorwrapper::layout::Logical(smooth_matrix(i, j)); } -inline auto tensor_logical(std::size_t i = 10, std::size_t j = 10, - std::size_t k = 10) { - return tensorwrapper::layout::Logical(smooth_tensor(i, j, k)); +inline auto tensor3_logical(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10) { + return tensorwrapper::layout::Logical(smooth_tensor3(i, j, k)); +} + +inline auto tensor4_logical(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10, std::size_t l = 10) { + return tensorwrapper::layout::Logical(smooth_tensor4(i, j, k, l)); } // ----------------------------------------------------------------------------- @@ -57,9 +62,14 @@ inline auto matrix_physical(std::size_t i = 10, std::size_t j = 10) { return tensorwrapper::layout::Physical(smooth_matrix(i, j)); } -inline auto tensor_physical(std::size_t i = 10, std::size_t j = 10, - std::size_t k = 10) { - return tensorwrapper::layout::Physical(smooth_tensor(i, j, k)); +inline auto tensor3_physical(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10) { + return tensorwrapper::layout::Physical(smooth_tensor3(i, j, k)); +} + +inline auto tensor4_physical(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10, std::size_t l = 10) { + return tensorwrapper::layout::Physical(smooth_tensor4(i, j, k, l)); } } // namespace tensorwrapper::testing \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp index 074dde12..28634628 100644 --- a/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp +++ b/tests/cxx/unit_tests/tensorwrapper/testing/shapes.hpp @@ -34,9 +34,14 @@ inline auto smooth_matrix(std::size_t i = 10, std::size_t j = 10) { return tensorwrapper::shape::Smooth{i, j}; } -inline auto smooth_tensor(std::size_t i = 10, std::size_t j = 10, - std::size_t k = 10) { +inline auto smooth_tensor3(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10) { return tensorwrapper::shape::Smooth{i, j, k}; } +inline auto smooth_tensor4(std::size_t i = 10, std::size_t j = 10, + std::size_t k = 10, std::size_t l = 10) { + return tensorwrapper::shape::Smooth{i, j, k, l}; +} + } // namespace tensorwrapper::testing \ No newline at end of file