diff --git a/docs/source/index.rst b/docs/source/index.rst index d407148..a00b14f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -25,6 +25,7 @@ working with floating-point (FP) numbers in a type-erased manner. developer/index adding_a_new_type + writing_a_visitor .. toctree:: :maxdepth: 2 diff --git a/docs/source/writing_a_visitor.rst b/docs/source/writing_a_visitor.rst new file mode 100644 index 0000000..0d764b9 --- /dev/null +++ b/docs/source/writing_a_visitor.rst @@ -0,0 +1,42 @@ +.. Copyright 2025 NWChemEx-Project +.. +.. Licensed under the Apache License, Version 2.0 (the "License"); +.. you may not use this file except in compliance with the License. +.. You may obtain a copy of the License at +.. +.. http://www.apache.org/licenses/LICENSE-2.0 +.. +.. Unless required by applicable law or agreed to in writing, software +.. distributed under the License is distributed on an "AS IS" BASIS, +.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +.. See the License for the specific language governing permissions and +.. limitations under the License. + +################# +Writing a Visitor +################# + +To avoid assuming a particular set of operations WTF relies on the visitor +pattern. This page should eventually be a tutorial, but for now is just a set +of notes. + +Assume that ``visitor`` is an object of type ``T``. The dispatch function will +additionally require a ``std::tuple`` where ``FPList...`` is a list +of the floating-point types you want to be able to support. Then assume we are +trying to dispatch based on the wrapped types of ``N`` ``FloatBuffer`` objects. + +- ``T`` must be a callable type, i.e., define ``operator()``. Functions, + lambdas, and functors all satisfy this criterion +- ``operator()`` may be overloaded. If it is overloaded, the usual C++ overload + resolution rules will be used to select the appropriate overload. +- Dispatching considers the type of the floating point buffers and the number of + floating point buffers. +- ``operator()`` overloads must accent ``N`` parameters. +- ``operator()`` must possess an overload capable of supporting any permutation + with replacement of the types in ``FPLists...``. + + - The easiest way to satisfy this is to have a templated overload with ``N`` + template parameters (one for the type of each positional argument). + +- Each ``FloatBuffer`` will be unwrapped to a ``std::span`` object where + ``U`` is the type of the wrapped floating point objects. diff --git a/include/wtf/buffer/detail_/contiguous_model.hpp b/include/wtf/buffer/detail_/contiguous_model.hpp index 6e695c2..2afe01d 100644 --- a/include/wtf/buffer/detail_/contiguous_model.hpp +++ b/include/wtf/buffer/detail_/contiguous_model.hpp @@ -134,7 +134,7 @@ class ContiguousModel : public BufferHolder { * * @throw None No-throw guarantee. */ - auto span() { return std::span(data(), size()); } + auto span() { return std::span(data(), size()); } /** @brief Returns the wrapped data as a std::span. * @@ -145,7 +145,7 @@ class ContiguousModel : public BufferHolder { * * @throw None No-throw guarantee. */ - auto span() const { return std::span(data(), size()); } + auto span() const { return std::span(data(), size()); } /** @brief Compares the elements in the buffer for exact equality. * @@ -243,7 +243,7 @@ class ContiguousModel : public BufferHolder { */ template auto visit_contiguous_model(Visitor&& visitor, Args&&... args) { - auto lambda = [=](auto&&... inner_args) { + auto lambda = [&](auto&&... inner_args) { return visitor(inner_args.span()...); }; return wtf::detail_::dispatch( diff --git a/include/wtf/buffer/float_buffer.hpp b/include/wtf/buffer/float_buffer.hpp index e09ab2c..0eb9dda 100644 --- a/include/wtf/buffer/float_buffer.hpp +++ b/include/wtf/buffer/float_buffer.hpp @@ -309,6 +309,10 @@ class FloatBuffer { template friend auto visit_contiguous_buffer(Visitor&& visitor, Args&&... args); + holder_type& holder_() { return *m_pholder_; } + + const holder_type& holder_() const { return *m_pholder_; } + template auto& downcast_() { using model_type = detail_::ContiguousModel; @@ -393,7 +397,7 @@ std::span contiguous_buffer_cast(FloatBuffer& buffer) { template auto visit_contiguous_buffer(Visitor&& visitor, Args&&... args) { return detail_::visit_contiguous_model( - std::forward(visitor), *args.m_pholder_...); + std::forward(visitor), args.holder_()...); } // ----------------------------------------------------------------------------- diff --git a/tests/unit_tests/wtf/buffer/float_buffer.cpp b/tests/unit_tests/wtf/buffer/float_buffer.cpp index d8ca89a..49385df 100644 --- a/tests/unit_tests/wtf/buffer/float_buffer.cpp +++ b/tests/unit_tests/wtf/buffer/float_buffer.cpp @@ -253,8 +253,8 @@ TEMPLATE_LIST_TEST_CASE("contiguous_buffer_cast", "[wtf]", all_fp_types) { REQUIRE(span[2] == three); } -struct CheckVisitContiguousModel { - CheckVisitContiguousModel(float* pdataf, double* pdatad) : +struct ConstVisitorConstSpan { + ConstVisitorConstSpan(float* pdataf, double* pdatad) : pdataf_corr(pdataf), pdatad_corr(pdatad) {} auto operator()(std::span span) const { @@ -283,13 +283,64 @@ struct CheckVisitContiguousModel { double* pdatad_corr; }; +struct VisitorConstSpan { + VisitorConstSpan(float* pdataf, double* pdatad) : + pdataf_corr(pdataf), pdatad_corr(pdatad) {} + + template + auto operator()(std::span span) { + if constexpr(std::is_same_v, float>) { + REQUIRE(span.data() == pdataf_corr); + m_called_float = true; + } else { + REQUIRE(span.data() == pdatad_corr); + m_called_double = true; + } + } + + template + auto operator()(std::span span) { + if constexpr(std::is_same_v, float>) { + REQUIRE(span.data() == pdataf_corr); + m_called_cfloat = true; + } else { + REQUIRE(span.data() == pdatad_corr); + m_called_cdouble = true; + } + REQUIRE(span.size() == 3); + } + + auto operator()(std::span lhs, std::span rhs) { + REQUIRE(lhs.data() == pdataf_corr); + REQUIRE(lhs.size() == 3); + REQUIRE(rhs.data() == pdatad_corr); + REQUIRE(rhs.size() == 3); + m_called_float_double = true; + } + + template + auto operator()(std::span lhs, std::span rhs) const { + throw std::runtime_error("Only float, double supported"); + } + + bool m_called_float = false; + bool m_called_cfloat = false; + bool m_called_double = false; + bool m_called_cdouble = false; + bool m_called_float_double = false; + + float* pdataf_corr; + double* pdatad_corr; +}; + TEST_CASE("visit_contiguous_buffer") { std::vector valf{1.0, 2.0, 3.0}; std::vector vald{1.0, 2.0, 3.0}; auto pdataf = valf.data(); auto pdatad = vald.data(); - CheckVisitContiguousModel visitor(pdataf, pdatad); + ConstVisitorConstSpan visitor(pdataf, pdatad); + VisitorConstSpan visitor2(pdataf, pdatad); FloatBuffer modelf(std::move(valf)); FloatBuffer modeld(std::move(vald)); @@ -297,11 +348,40 @@ TEST_CASE("visit_contiguous_buffer") { using type_tuple = std::tuple; SECTION("one argument") { - visit_contiguous_buffer(visitor, modelf); - visit_contiguous_buffer(visitor, modeld); + SECTION("call span/ span overloads") { + visit_contiguous_buffer(visitor, modelf); + visit_contiguous_buffer(visitor, std::as_const(modelf)); + visit_contiguous_buffer(visitor, modeld); + visit_contiguous_buffer(visitor, std::as_const(modeld)); + } + + SECTION("call span overload") { + const auto& cmodelf = modelf; + visit_contiguous_buffer(visitor2, cmodelf); + REQUIRE(visitor2.m_called_cfloat); + REQUIRE_FALSE(visitor2.m_called_float); + + const auto& cmodeld = modeld; + visit_contiguous_buffer(visitor2, cmodeld); + REQUIRE(visitor2.m_called_cdouble); + REQUIRE_FALSE(visitor2.m_called_double); + } + + SECTION("calls span overload") { + visit_contiguous_buffer(visitor2, modelf); + REQUIRE_FALSE(visitor2.m_called_cfloat); + REQUIRE(visitor2.m_called_float); + + visit_contiguous_buffer(visitor2, modeld); + REQUIRE_FALSE(visitor2.m_called_cdouble); + REQUIRE(visitor2.m_called_double); + } } SECTION("Two arguments") { visit_contiguous_buffer(visitor, modelf, modeld); + + visit_contiguous_buffer(visitor2, modelf, modeld); + REQUIRE(visitor2.m_called_float_double); } }