Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ working with floating-point (FP) numbers in a type-erased manner.

developer/index
adding_a_new_type
writing_a_visitor

.. toctree::
:maxdepth: 2
Expand Down
42 changes: 42 additions & 0 deletions docs/source/writing_a_visitor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
.. Copyright 2025 NWChemEx-Project
..
.. Licensed under the Apache License, Version 2.0 (the "License");
.. you may not use this file except in compliance with the License.
.. You may obtain a copy of the License at
..
.. http://www.apache.org/licenses/LICENSE-2.0
..
.. Unless required by applicable law or agreed to in writing, software
.. distributed under the License is distributed on an "AS IS" BASIS,
.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
.. See the License for the specific language governing permissions and
.. limitations under the License.

#################
Writing a Visitor
#################

To avoid assuming a particular set of operations WTF relies on the visitor
pattern. This page should eventually be a tutorial, but for now is just a set
of notes.

Assume that ``visitor`` is an object of type ``T``. The dispatch function will
additionally require a ``std::tuple<FPList...>`` where ``FPList...`` is a list
of the floating-point types you want to be able to support. Then assume we are
trying to dispatch based on the wrapped types of ``N`` ``FloatBuffer`` objects.

- ``T`` must be a callable type, i.e., define ``operator()``. Functions,
lambdas, and functors all satisfy this criterion
- ``operator()`` may be overloaded. If it is overloaded, the usual C++ overload
resolution rules will be used to select the appropriate overload.
- Dispatching considers the type of the floating point buffers and the number of
floating point buffers.
- ``operator()`` overloads must accent ``N`` parameters.
- ``operator()`` must possess an overload capable of supporting any permutation
with replacement of the types in ``FPLists...``.

- The easiest way to satisfy this is to have a templated overload with ``N``
template parameters (one for the type of each positional argument).

- Each ``FloatBuffer`` will be unwrapped to a ``std::span<U>`` object where
``U`` is the type of the wrapped floating point objects.
6 changes: 3 additions & 3 deletions include/wtf/buffer/detail_/contiguous_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ContiguousModel : public BufferHolder {
*
* @throw None No-throw guarantee.
*/
auto span() { return std::span(data(), size()); }
auto span() { return std::span<FloatType>(data(), size()); }

/** @brief Returns the wrapped data as a std::span.
*
Expand All @@ -145,7 +145,7 @@ class ContiguousModel : public BufferHolder {
*
* @throw None No-throw guarantee.
*/
auto span() const { return std::span(data(), size()); }
auto span() const { return std::span<const FloatType>(data(), size()); }

/** @brief Compares the elements in the buffer for exact equality.
*
Expand Down Expand Up @@ -243,7 +243,7 @@ class ContiguousModel : public BufferHolder {
*/
template<typename TupleType, typename Visitor, typename... Args>
auto visit_contiguous_model(Visitor&& visitor, Args&&... args) {
auto lambda = [=](auto&&... inner_args) {
auto lambda = [&](auto&&... inner_args) {
return visitor(inner_args.span()...);
};
return wtf::detail_::dispatch<ContiguousModel, TupleType>(
Expand Down
6 changes: 5 additions & 1 deletion include/wtf/buffer/float_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ class FloatBuffer {
template<typename TupleType, typename Visitor, typename... Args>
friend auto visit_contiguous_buffer(Visitor&& visitor, Args&&... args);

holder_type& holder_() { return *m_pholder_; }

const holder_type& holder_() const { return *m_pholder_; }

template<typename T>
auto& downcast_() {
using model_type = detail_::ContiguousModel<T>;
Expand Down Expand Up @@ -393,7 +397,7 @@ std::span<T> contiguous_buffer_cast(FloatBuffer& buffer) {
template<typename TupleType, typename Visitor, typename... Args>
auto visit_contiguous_buffer(Visitor&& visitor, Args&&... args) {
return detail_::visit_contiguous_model<TupleType>(
std::forward<Visitor>(visitor), *args.m_pholder_...);
std::forward<Visitor>(visitor), args.holder_()...);
}

// -----------------------------------------------------------------------------
Expand Down
90 changes: 85 additions & 5 deletions tests/unit_tests/wtf/buffer/float_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ TEMPLATE_LIST_TEST_CASE("contiguous_buffer_cast", "[wtf]", all_fp_types) {
REQUIRE(span[2] == three);
}

struct CheckVisitContiguousModel {
CheckVisitContiguousModel(float* pdataf, double* pdatad) :
struct ConstVisitorConstSpan {
ConstVisitorConstSpan(float* pdataf, double* pdatad) :
pdataf_corr(pdataf), pdatad_corr(pdatad) {}

auto operator()(std::span<const float> span) const {
Expand Down Expand Up @@ -283,25 +283,105 @@ struct CheckVisitContiguousModel {
double* pdatad_corr;
};

struct VisitorConstSpan {
VisitorConstSpan(float* pdataf, double* pdatad) :
pdataf_corr(pdataf), pdatad_corr(pdatad) {}

template<typename T>
auto operator()(std::span<T> span) {
if constexpr(std::is_same_v<std::remove_const_t<T>, float>) {
REQUIRE(span.data() == pdataf_corr);
m_called_float = true;
} else {
REQUIRE(span.data() == pdatad_corr);
m_called_double = true;
}
}

template<typename T>
auto operator()(std::span<const T> span) {
if constexpr(std::is_same_v<std::remove_const_t<T>, float>) {
REQUIRE(span.data() == pdataf_corr);
m_called_cfloat = true;
} else {
REQUIRE(span.data() == pdatad_corr);
m_called_cdouble = true;
}
REQUIRE(span.size() == 3);
}

auto operator()(std::span<float> lhs, std::span<double> rhs) {
REQUIRE(lhs.data() == pdataf_corr);
REQUIRE(lhs.size() == 3);
REQUIRE(rhs.data() == pdatad_corr);
REQUIRE(rhs.size() == 3);
m_called_float_double = true;
}

template<typename T, typename U>
auto operator()(std::span<T> lhs, std::span<U> rhs) const {
throw std::runtime_error("Only float, double supported");
}

bool m_called_float = false;
bool m_called_cfloat = false;
bool m_called_double = false;
bool m_called_cdouble = false;
bool m_called_float_double = false;

float* pdataf_corr;
double* pdatad_corr;
};

TEST_CASE("visit_contiguous_buffer") {
std::vector<float> valf{1.0, 2.0, 3.0};
std::vector<double> vald{1.0, 2.0, 3.0};
auto pdataf = valf.data();
auto pdatad = vald.data();

CheckVisitContiguousModel visitor(pdataf, pdatad);
ConstVisitorConstSpan visitor(pdataf, pdatad);
VisitorConstSpan visitor2(pdataf, pdatad);

FloatBuffer modelf(std::move(valf));
FloatBuffer modeld(std::move(vald));

using type_tuple = std::tuple<float, double>;

SECTION("one argument") {
visit_contiguous_buffer<type_tuple>(visitor, modelf);
visit_contiguous_buffer<type_tuple>(visitor, modeld);
SECTION("call span<const float>/ span<const double> overloads") {
visit_contiguous_buffer<type_tuple>(visitor, modelf);
visit_contiguous_buffer<type_tuple>(visitor, std::as_const(modelf));
visit_contiguous_buffer<type_tuple>(visitor, modeld);
visit_contiguous_buffer<type_tuple>(visitor, std::as_const(modeld));
}

SECTION("call span<const T> overload") {
const auto& cmodelf = modelf;
visit_contiguous_buffer<type_tuple>(visitor2, cmodelf);
REQUIRE(visitor2.m_called_cfloat);
REQUIRE_FALSE(visitor2.m_called_float);

const auto& cmodeld = modeld;
visit_contiguous_buffer<type_tuple>(visitor2, cmodeld);
REQUIRE(visitor2.m_called_cdouble);
REQUIRE_FALSE(visitor2.m_called_double);
}

SECTION("calls span<T> overload") {
visit_contiguous_buffer<type_tuple>(visitor2, modelf);
REQUIRE_FALSE(visitor2.m_called_cfloat);
REQUIRE(visitor2.m_called_float);

visit_contiguous_buffer<type_tuple>(visitor2, modeld);
REQUIRE_FALSE(visitor2.m_called_cdouble);
REQUIRE(visitor2.m_called_double);
}
}

SECTION("Two arguments") {
visit_contiguous_buffer<type_tuple>(visitor, modelf, modeld);

visit_contiguous_buffer<type_tuple>(visitor2, modelf, modeld);
REQUIRE(visitor2.m_called_float_double);
}
}