Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tensorwrapper/allocator/allocator_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ class AllocatorBase : public detail_::PolymorphicBase<AllocatorBase> {
/// Type all buffers derive from
using buffer_base_type = buffer::BufferBase;

/// Type of a mutable reference to an object of type buffer_base_type
using buffer_base_reference = buffer_base_type&;

/// Type of a read-only reference to an object of type buffer_base_type
using const_buffer_base_reference = const buffer_base_type&;

/// Type of a pointer to an object of type buffer_base_type
using buffer_base_pointer = typename buffer_base_type::buffer_base_pointer;

Expand Down
57 changes: 56 additions & 1 deletion include/tensorwrapper/allocator/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#pragma once
#include <tensorwrapper/allocator/replicated.hpp>
#include <tensorwrapper/buffer/eigen.hpp>
#include <tensorwrapper/buffer/buffer_fwd.hpp>

namespace tensorwrapper::allocator {

Expand All @@ -42,13 +42,21 @@ class Eigen : public Replicated {
// Pull in base class's types
using my_base_type::base_pointer;
using my_base_type::buffer_base_pointer;
using my_base_type::buffer_base_reference;
using my_base_type::const_base_reference;
using my_base_type::const_buffer_base_reference;
using my_base_type::layout_pointer;
using my_base_type::runtime_view_type;

/// Type of a buffer containing an Eigen tensor
using eigen_buffer_type = buffer::Eigen<FloatType, Rank>;

/// Type of a mutable reference to an object of type eigen_buffer_type
using eigen_buffer_reference = eigen_buffer_type&;

/// Type of a read-only reference to an object of type eigen_buffer_type
using const_eigen_buffer_reference = const eigen_buffer_type&;

/// Type of a pointer to an eigen_buffer_type object
using eigen_buffer_pointer = std::unique_ptr<eigen_buffer_type>;

Expand Down Expand Up @@ -140,6 +148,53 @@ class Eigen : public Replicated {
return pbuffer;
}

/** @brief Determines if @p buffer can be rebound as an Eigen buffer.
*
* Rebinding a buffer allows the same memory to be viewed as a (possibly)
* different type of buffer.
*
* @param[in] buffer The tensor we are attempting to rebind.
*
* @return True if @p buffer can be rebound to the type of buffer
* associated with this allocator and false otherwise.
*
* @throw None No throw guarantee
*/
static bool can_rebind(const_buffer_base_reference buffer);

/** @brief Rebinds a buffer to the same type as *this.
*
* This method will convert @p buffer into a buffer which could have been
* allocated by *this. If @p buffer was allocated as such a buffer already,
* then this method is simply a downcast.
*
* @param[in] buffer The buffer to rebind.
*
* @return A mutable reference to @p buffer viewed as a buffer that could
* have been allocated by *this.
*
* @throw std::runtime_error if can_rebind(buffer) is false. Strong throw
* guarantee.
*/
static eigen_buffer_reference rebind(buffer_base_reference buffer);

/** @brief Rebinds a buffer to the same type as *this.
*
* This method is the same as the non-const version except that the result
* is read-only. See the description for the non-const version for more
* details.
*
* @param[in] buffer The buffer to rebind.
*
* @return A read-only reference to @p buffer viewed as if it was
* allocated by *this.
*
* @throw std::runtime_error if can_rebind(buffer) is false. Strong throw
* guarantee.
*/
static const_eigen_buffer_reference rebind(
const_buffer_base_reference buffer);

/** @brief Is *this value equal to @p rhs?
*
* @tparam FloatType2 The numerical type @p rhs uses for its elements.
Expand Down
150 changes: 150 additions & 0 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#pragma once
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/layout/layout_base.hpp>

namespace tensorwrapper::buffer {

/** @brief Common base class for all buffer objects.
Expand All @@ -35,13 +37,19 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
/// Type all buffers inherit from
using buffer_base_type = typename my_base_type::base_type;

/// Type of a mutable reference to a buffer_base_type object
using buffer_base_reference = typename my_base_type::base_reference;

/// Type of a read-only reference to a buffer_base_type object
using const_buffer_base_reference =
typename my_base_type::const_base_reference;

/// Type of a pointer to an object of type buffer_base_type
using buffer_base_pointer = typename my_base_type::base_pointer;

/// Type of a pointer to a read-only object of type buffer_base_type
using const_buffer_base_pointer = typename my_base_type::const_base_pointer;

/// Type of the class describing the physical layout of the buffer
using layout_type = layout::LayoutBase;

Expand All @@ -51,6 +59,18 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
/// Type of a pointer to the layout
using layout_pointer = typename layout_type::layout_pointer;

/// Type of labels for making a labeled buffer
using label_type = std::string;

/// Type of a labeled buffer
using labeled_buffer_type = dsl::Labeled<buffer_base_type, label_type>;

/// Type of a labeled read-only buffer (n.b. labels are mutable)
using labeled_const_buffer_type = dsl::Labeled<const buffer_base_type>;

/// Type of a read-only reference to a labeled_buffer_type object
using const_labeled_buffer_reference = const labeled_const_buffer_type&;

// -------------------------------------------------------------------------
// -- Accessors
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -82,10 +102,128 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *m_layout_;
}

// -------------------------------------------------------------------------
// -- BLAS Operations
// -------------------------------------------------------------------------

/** @brief Set this to the result of *this + rhs.
*
* This method will overwrite the state of *this with the result of
* adding the original state of *this to that of @p rhs. Depending on the
* value @p this_labels compared to the labels associated with @p rhs,
* it may be a permutation of @p rhs that is added to *this.
*
* @param[in] this_labels The labels to associate with the modes of *this.
* @param[in] rhs The buffer to add into *this.
*
* @throws ??? Throws if the derived class's implementation throws. Same
* throw guarantee.
*/
buffer_base_reference addition_assignment(
label_type this_labels, const_labeled_buffer_reference rhs) {
return addition_assignment_(std::move(this_labels), rhs);
}

/** @brief Returns the result of *this + rhs.
*
* This method is the same as addition_assignment except that the result
* is returned in a newly allocated buffer instead of overwriting *this.
*
* @param[in] this_labels the labels for the modes of *this.
* @param[in] rhs The buffer to add to *this.
*
* @return The buffer resulting from adding *this to @p rhs.
*
* @throw std::bad_alloc if there is a problem copying *this. Strong throw
* guarantee.
* @throw ??? If addition_assignment throws when adding @p rhs to the
* copy of *this. Same throw guarantee.
*/
buffer_base_pointer addition(label_type this_labels,
const_labeled_buffer_reference rhs) const {
auto pthis = clone();
pthis->addition_assignment(std::move(this_labels), rhs);
return pthis;
}

/** @brief Sets *this to a permutation of @p rhs.
*
* `rhs.rhs()` are the dummy indices associated with the modes of the
* buffer in @p rhs and @p this_labels are the dummy indices associated
* with the buffer in *this. This method will permute @p rhs so that the
* resulting buffer's modes are ordered consistently with @p this_labels,
* i.e. the permutation is FROM the `rhs.rhs()` order TO the
* @p this_labels order. This is seemingly backwards when described out,
* but consistent with the intent of a DSL expression like
* `t("i,j") = x("j,i");` where the intent is to set `t` equal to the
* transpose of `x`.
*
* @param[in] this_labels the dummy indices for the modes of *this.
* @param[in] rhs The tensor to permute.
*
* @return *this after setting it equal to a permutation of @p rhs.
*
* @throw ??? If the derived class's implementation of permute_assignment_
* throws. Same throw guarantee.
*/
buffer_base_reference permute_assignment(
label_type this_labels, const_labeled_buffer_reference rhs) {
return permute_assignment_(std::move(this_labels), rhs);
}

/** @brief Returns a copy of *this obtained by permuting *this.
*
* This method simply calls permute_assignment on a copy of *this. See the
* description of permute_assignment for more details.
*
* @param[in] this_labels dummy indices representing the modes of *this in
* its current state.
* @param[in] out_labels how the user wants the modes of *this to be
* ordered.
*
* @throw std::bad_alloc if there is a problem allocating the copy. Strong
* throw guarantee.
* @throw ??? If the derived class's implementation of permute_assignment_
* throws. Same throw guarantee.
*/
buffer_base_pointer permute(label_type this_labels,
label_type out_labels) const {
auto pthis = clone();
pthis->permute_assignment(std::move(out_labels), (*this)(this_labels));
return pthis;
}

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Associates labels with the modes of *this.
*
* This method is used to create a labeled buffer object by pairing *this
* with the provided labels. The resulting object is capable of being
* composed via the DSL.
*
* @param[in] labels The indices to associate with the modes of *this.
*
* @return A DSL term pairing *this with @p labels.
*
* @throw None No throw guarantee.
*/
labeled_buffer_type operator()(label_type labels);

/** @brief Associates labels with the modes of *this.
*
* This method is the same as the non-const version except that the result
* contains a read-only reference to *this.
*
* @param[in] labels The labels to associate with *this.
*
* @return A DSL term pairing *this with @p labels.
*
* @throw None No throw guarantee.
*/
labeled_const_buffer_type operator()(label_type labels) const;

/** @brief Is *this value equal to @p rhs?
*
* Two BufferBase objects are value equal if the layouts they contain are
Expand Down Expand Up @@ -183,6 +321,18 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *this;
}

/// Derived class should overwrite to implement addition_assignment
virtual buffer_base_reference addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) {
throw std::runtime_error("Addition assignment NYI");
}

/// Derived class should overwrite to implement permute_assignment
virtual buffer_base_reference permute_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) {
throw std::runtime_error("Permute assignment NYI");
}

private:
/// Throws std::runtime_error when there is no layout
void assert_layout_() const {
Expand Down
30 changes: 30 additions & 0 deletions include/tensorwrapper/buffer/buffer_fwd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace tensorwrapper::buffer {

class BufferBase;

template<typename FloatType, unsigned short Rank>
class Eigen;

class Local;

class Replicated;

} // namespace tensorwrapper::buffer
31 changes: 31 additions & 0 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class Eigen : public Replicated {
/// Pull in base class's types
using typename my_base_type::buffer_base_pointer;
using typename my_base_type::const_buffer_base_reference;
using typename my_base_type::const_labeled_buffer_reference;
using typename my_base_type::const_layout_reference;
using typename my_base_type::label_type;

/// Type of a rank @p Rank tensor using floats of type @p FloatType
using data_type = eigen::data_type<FloatType, Rank>;
Expand Down Expand Up @@ -180,9 +182,38 @@ class Eigen : public Replicated {
return my_base_type::are_equal_impl_<my_type>(rhs);
}

/// Implements addition_assignment by rebinding rhs to an Eigen buffer
buffer_base_reference addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) override;

/// Implements permute assignment by deferring to Eigen's shuffle command.
buffer_base_reference permute_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) override;

/// Implements to_string
typename my_base_type::string_type to_string_() const override;

private:
/// The actual Eigen tensor
data_type m_tensor_;
};

#define DECLARE_EIGEN_BUFFER(RANK) \
extern template class Eigen<float, RANK>; \
extern template class Eigen<double, RANK>

DECLARE_EIGEN_BUFFER(0);
DECLARE_EIGEN_BUFFER(1);
DECLARE_EIGEN_BUFFER(2);
DECLARE_EIGEN_BUFFER(3);
DECLARE_EIGEN_BUFFER(4);
DECLARE_EIGEN_BUFFER(5);
DECLARE_EIGEN_BUFFER(6);
DECLARE_EIGEN_BUFFER(7);
DECLARE_EIGEN_BUFFER(8);
DECLARE_EIGEN_BUFFER(9);
DECLARE_EIGEN_BUFFER(10);

#undef DECLARE_EIGEN_BUFFER

} // namespace tensorwrapper::buffer
Loading
Loading