diff --git a/include/ygm/container/bag.hpp b/include/ygm/container/bag.hpp index 0bccf398..0f6824e8 100644 --- a/include/ygm/container/bag.hpp +++ b/include/ygm/container/bag.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +29,7 @@ namespace ygm::container { */ template class bag : public detail::base_async_insert_value, std::tuple>, + public detail::base_contains, std::tuple>, public detail::base_count, std::tuple>, public detail::base_misc, std::tuple>, public detail::base_iterators>, @@ -259,6 +261,17 @@ class bag : public detail::base_async_insert_value, std::tuple>, return std::count(m_local_bag.begin(), m_local_bag.end(), val); } + /** + * @brief Check if a value exists locally + * + * @param val Value to check for + * @return True if value exists locally, false otherwise + */ + bool local_contains(const value_type &val) const { + return std::find(m_local_bag.begin(), m_local_bag.end(), val) != + m_local_bag.end(); + } + /** * @brief Execute a functor on every locally-held item * diff --git a/include/ygm/container/counting_set.hpp b/include/ygm/container/counting_set.hpp index 4f51ef57..7e834967 100644 --- a/include/ygm/container/counting_set.hpp +++ b/include/ygm/container/counting_set.hpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -25,6 +26,7 @@ namespace ygm::container { template class counting_set : public detail::base_count, std::tuple>, + public detail::base_contains, std::tuple>, public detail::base_misc, std::tuple>, public detail::base_iterators>, public detail::base_iteration_key_value, @@ -286,6 +288,16 @@ class counting_set return local_count; } + /** + * @brief Check if a locally-held item exists + * + * @param val Value to check for + * @return true if value exists locally, false otherwise + */ + bool local_contains(const key_type &key) const { + return m_map.local_contains(key); + } + /** * @brief Count the total number of items counted * diff --git a/include/ygm/container/detail/base_contains.hpp b/include/ygm/container/detail/base_contains.hpp new file mode 100644 index 00000000..2ecce8e9 --- /dev/null +++ b/include/ygm/container/detail/base_contains.hpp @@ -0,0 +1,35 @@ +// Copyright 2019-2025 Lawrence Livermore National Security, LLC and other YGM +// Project Developers. See the top-level COPYRIGHT file for details. +// +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ygm::container::detail { + +/** + * @brief Curiously-recurring template parameter struct that provides + * count operation + */ +template +struct base_contains { + /** + * @brief Checks for the presence of a value within a container. + * + * @param value Value to search for within container (key in the case of + * containers with keys) + * @return True if `value` exists in container; false otherwise. + */ + bool contains( + const typename std::tuple_element<0, for_all_args>::type& value) const { + const derived_type* derived_this = static_cast(this); + derived_this->comm().barrier(); + return ::ygm::logical_or(derived_this->local_contains(value), + derived_this->comm()); + } +}; + +} // namespace ygm::container::detail diff --git a/include/ygm/container/map.hpp b/include/ygm/container/map.hpp index d4594f4a..3c84c6cf 100644 --- a/include/ygm/container/map.hpp +++ b/include/ygm/container/map.hpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -28,6 +29,7 @@ class map public detail::base_async_insert_or_assign, std::tuple>, public detail::base_misc, std::tuple>, + public detail::base_contains, std::tuple>, public detail::base_count, std::tuple>, public detail::base_async_reduce, std::tuple>, public detail::base_async_erase_key, @@ -562,6 +564,16 @@ class map return m_local_map.count(key); } + /** + * @brief Check if a key exists locally + * + * @param key key to check for + * @return True if `key` exists locally, false otherwise + */ + bool local_contains(const key_type& key) const { + return m_local_map.contains(key); + } + // void serialize(const std::string& fname) { m_impl.serialize(fname); } // void deserialize(const std::string& fname) { m_impl.deserialize(fname); } diff --git a/include/ygm/container/set.hpp b/include/ygm/container/set.hpp index 31928c85..ce9dacea 100644 --- a/include/ygm/container/set.hpp +++ b/include/ygm/container/set.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ class set public detail::base_batch_erase_key, std::tuple>, public detail::base_async_contains, std::tuple>, public detail::base_async_insert_contains, std::tuple>, + public detail::base_contains, std::tuple>, public detail::base_count, std::tuple>, public detail::base_misc, std::tuple>, public detail::base_iterators>, @@ -228,6 +230,16 @@ class set return m_local_set.count(val); } + /** + * @brief Check if a value exists locally + * + * @param val Value to check for + * @return True if value exists locally, false otherwise + */ + bool local_contains(const value_type &val) const { + return m_local_set.contains(val); + } + /** * @brief Get the number of elements stored on the local process. * diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6a2b273f..020324c3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -48,7 +48,7 @@ add_ygm_test(test_set) add_ygm_test(test_bag) #add_ygm_test(test_tagged_bag) add_ygm_test(test_array) -#add_ygm_test(test_counting_set) +# add_ygm_test(test_counting_set) add_ygm_test(test_disjoint_set) #add_ygm_test(test_container_serialization) add_ygm_test(test_line_parser) diff --git a/test/test_bag.cpp b/test/test_bag.cpp index 9b844e35..77584624 100644 --- a/test/test_bag.cpp +++ b/test/test_bag.cpp @@ -40,6 +40,12 @@ int main(int argc, char** argv) { YGM_ASSERT_RELEASE(bbag.count("red") == 1); YGM_ASSERT_RELEASE(bbag.size() == 3); + // test contains. + YGM_ASSERT_RELEASE(bbag.contains("dog")); + YGM_ASSERT_RELEASE(bbag.contains("apple")); + YGM_ASSERT_RELEASE(bbag.contains("red")); + YGM_ASSERT_RELEASE(!bbag.contains("blue")); + for (auto& value : bbag) { world.cout(value); } @@ -158,6 +164,12 @@ int main(int argc, char** argv) { YGM_ASSERT_RELEASE(bbag.count("apple") == (size_t)world.size()); YGM_ASSERT_RELEASE(bbag.count("red") == (size_t)world.size()); + // test contains + YGM_ASSERT_RELEASE(bbag.contains("dog")); + YGM_ASSERT_RELEASE(bbag.contains("apple")); + YGM_ASSERT_RELEASE(bbag.contains("red")); + YGM_ASSERT_RELEASE(!bbag.contains("blue")); + { std::vector all_data; bbag.gather(all_data, 0); diff --git a/test/test_counting_set.cpp b/test/test_counting_set.cpp index e43a6c9c..9f88d3fc 100644 --- a/test/test_counting_set.cpp +++ b/test/test_counting_set.cpp @@ -44,6 +44,12 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(count_map["dog"] == 1); YGM_ASSERT_RELEASE(count_map["apple"] == 1); YGM_ASSERT_RELEASE(count_map.count("cat") == 0); + + // test contains + YGM_ASSERT_RELEASE(cset.contains("dog")); + YGM_ASSERT_RELEASE(cset.contains("apple")); + YGM_ASSERT_RELEASE(cset.contains("red")); + YGM_ASSERT_RELEASE(!cset.contains("blue")); } // @@ -66,6 +72,12 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(cset.count("cat") == 0); YGM_ASSERT_RELEASE(cset.count_all() == 3 * (size_t)world.size()); + + // test contains + YGM_ASSERT_RELEASE(cset.contains("dog")); + YGM_ASSERT_RELEASE(cset.contains("apple")); + YGM_ASSERT_RELEASE(cset.contains("red")); + YGM_ASSERT_RELEASE(!cset.contains("blue")); } // @@ -113,6 +125,12 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(cset.count("dog") == 0); YGM_ASSERT_RELEASE(cset.count("apple") == 0); YGM_ASSERT_RELEASE(cset.count("red") == 0); + + // test contains + YGM_ASSERT_RELEASE(!cset.contains("dog")); + YGM_ASSERT_RELEASE(!cset.contains("apple")); + YGM_ASSERT_RELEASE(!cset.contains("red")); + YGM_ASSERT_RELEASE(!cset.contains("blue")); } // // @@ -166,6 +184,12 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(cset2.count("bird") == (size_t)world.size()); YGM_ASSERT_RELEASE(cset2.count("red") == 0); YGM_ASSERT_RELEASE(cset2.size() == 3); + + // test contains + YGM_ASSERT_RELEASE(cset2.contains("dog")); + YGM_ASSERT_RELEASE(cset2.contains("cat")); + YGM_ASSERT_RELEASE(cset2.contains("bird")); + YGM_ASSERT_RELEASE(!cset2.contains("red")); } // diff --git a/test/test_map.cpp b/test/test_map.cpp index 8411c10f..ef192f41 100644 --- a/test/test_map.cpp +++ b/test/test_map.cpp @@ -40,6 +40,12 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(smap.count("dog") == 1); YGM_ASSERT_RELEASE(smap.count("apple") == 1); YGM_ASSERT_RELEASE(smap.count("red") == 1); + + // test contains. + YGM_ASSERT_RELEASE(smap.contains("dog")); + YGM_ASSERT_RELEASE(smap.contains("apple")); + YGM_ASSERT_RELEASE(smap.contains("red")); + YGM_ASSERT_RELEASE(!smap.contains("blue")); } // @@ -54,6 +60,12 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(smap.count("dog") == 1); YGM_ASSERT_RELEASE(smap.count("apple") == 1); YGM_ASSERT_RELEASE(smap.count("red") == 1); + + // test contains. + YGM_ASSERT_RELEASE(smap.contains("dog")); + YGM_ASSERT_RELEASE(smap.contains("apple")); + YGM_ASSERT_RELEASE(smap.contains("red")); + YGM_ASSERT_RELEASE(!smap.contains("blue")); } // @@ -135,13 +147,21 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(smap.size() == 2); + // test contains. + YGM_ASSERT_RELEASE(smap.contains("dog")); + YGM_ASSERT_RELEASE(smap.contains("cat")); + YGM_ASSERT_RELEASE(!smap.contains("red")); + YGM_ASSERT_RELEASE(!smap.contains("blue")); + if (world.rank() == 0) { smap.async_erase("dog"); } YGM_ASSERT_RELEASE(smap.count("dog") == 0); YGM_ASSERT_RELEASE(smap.size() == 1); + YGM_ASSERT_RELEASE(!smap.contains("dog")); smap.async_erase("cat"); YGM_ASSERT_RELEASE(smap.count("cat") == 0); + YGM_ASSERT_RELEASE(!smap.contains("cat")); YGM_ASSERT_RELEASE(smap.size() == 0); } @@ -259,9 +279,17 @@ int main(int argc, char **argv) { YGM_ASSERT_RELEASE(smap.count("dog") == 1); YGM_ASSERT_RELEASE(smap.count("apple") == 1); YGM_ASSERT_RELEASE(smap.count("red") == 1); + + YGM_ASSERT_RELEASE(!smap.contains("car")); smap.async_insert_or_assign("car", "truck"); YGM_ASSERT_RELEASE(smap.size() == 4); YGM_ASSERT_RELEASE(smap.count("car") == 1); + + // test contains. + YGM_ASSERT_RELEASE(smap.contains("dog")); + YGM_ASSERT_RELEASE(smap.contains("car")); + YGM_ASSERT_RELEASE(smap.contains("red")); + YGM_ASSERT_RELEASE(!smap.contains("blue")); } // Test batch erase from set diff --git a/test/test_set.cpp b/test/test_set.cpp index d3c976d7..43e21ed9 100644 --- a/test/test_set.cpp +++ b/test/test_set.cpp @@ -39,6 +39,12 @@ int main(int argc, char** argv) { YGM_ASSERT_RELEASE(sset.count("apple") == 1); YGM_ASSERT_RELEASE(sset.size() == 3); + // test contains. + YGM_ASSERT_RELEASE(sset.contains("dog")); + YGM_ASSERT_RELEASE(sset.contains("apple")); + YGM_ASSERT_RELEASE(sset.contains("red")); + YGM_ASSERT_RELEASE(!sset.contains("blue")); + ygm::container::set iset(world); if (world.rank() == 0) { iset.async_insert(42); @@ -49,6 +55,12 @@ int main(int argc, char** argv) { YGM_ASSERT_RELEASE(iset.count(7) == 1); YGM_ASSERT_RELEASE(iset.count(100) == 1); YGM_ASSERT_RELEASE(iset.size() == 3); + + // test contains. + YGM_ASSERT_RELEASE(iset.contains(42)); + YGM_ASSERT_RELEASE(iset.contains(7)); + YGM_ASSERT_RELEASE(iset.contains(100)); + YGM_ASSERT_RELEASE(!iset.contains(3)); } // @@ -65,6 +77,12 @@ int main(int argc, char** argv) { YGM_ASSERT_RELEASE(sset.count("apple") == 1); YGM_ASSERT_RELEASE(sset.count("red") == 1); YGM_ASSERT_RELEASE(sset.size() == 3); + + // test contains. + YGM_ASSERT_RELEASE(sset.contains("dog")); + YGM_ASSERT_RELEASE(sset.contains("apple")); + YGM_ASSERT_RELEASE(sset.contains("red")); + YGM_ASSERT_RELEASE(!sset.contains("blue")); } // @@ -83,6 +101,12 @@ int main(int argc, char** argv) { sset.async_erase("dog"); YGM_ASSERT_RELEASE(sset.count("dog") == 0); YGM_ASSERT_RELEASE(sset.size() == 2); + + // test contains. + YGM_ASSERT_RELEASE(sset.contains("apple")); + YGM_ASSERT_RELEASE(sset.contains("red")); + YGM_ASSERT_RELEASE(!sset.contains("dog")); + YGM_ASSERT_RELEASE(!sset.contains("blue")); } // @@ -262,6 +286,13 @@ int main(int argc, char** argv) { sset.async_insert("car"); YGM_ASSERT_RELEASE(sset.size() == 4); YGM_ASSERT_RELEASE(sset.count("car") == 1); + + // test contains. + YGM_ASSERT_RELEASE(sset.contains("apple")); + YGM_ASSERT_RELEASE(sset.contains("red")); + YGM_ASSERT_RELEASE(sset.contains("dog")); + YGM_ASSERT_RELEASE(sset.contains("car")); + YGM_ASSERT_RELEASE(!sset.contains("blue")); } // @@ -279,6 +310,12 @@ int main(int argc, char** argv) { YGM_ASSERT_RELEASE(sset2.count("dog") == 1); YGM_ASSERT_RELEASE(sset2.count("apple") == 1); YGM_ASSERT_RELEASE(sset2.count("red") == 1); + + // test contains. + YGM_ASSERT_RELEASE(sset2.contains("apple")); + YGM_ASSERT_RELEASE(sset2.contains("red")); + YGM_ASSERT_RELEASE(sset2.contains("dog")); + YGM_ASSERT_RELEASE(!sset2.contains("blue")); } // //