Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions btas/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,24 @@ namespace btas {
return y; /* automatically called move semantics */
}

/// adds a number to every element
Tensor operator+(const value_type& x) const {
Tensor y = this->clone();
y += x;
return y; /* automatically called move semantics */
}

/// adds a number to every element
Tensor& operator+=(const value_type& x) {
using std::begin;
using std::cbegin;
using std::cend;
std::transform(cbegin(storage_), cend(storage_), begin(storage_), [x](const auto& v) {
return v+x;
});
return *this;
}

/// subtraction assignment
Tensor& operator-=(const Tensor& x) {
using std::begin;
Expand All @@ -603,6 +621,24 @@ namespace btas {
return y; /* automatically called move semantics */
}

/// subtracts a number from every element
Tensor operator-(const value_type& x) const {
Tensor y = this->clone();
y -= x;
return y; /* automatically called move semantics */
}

/// subtracts a number from every element
Tensor& operator-=(const value_type& x) {
using std::begin;
using std::cbegin;
using std::cend;
std::transform(cbegin(storage_), cend(storage_), begin(storage_), [x](const auto& v) {
return v-x;
});
return *this;
}

/// \return bare const pointer to the first element of data_
/// this enables to call BLAS functions
const_pointer data() const {
Expand Down
8 changes: 8 additions & 0 deletions unittest/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ TEST_CASE("Tensor Operations") {
CHECK_NOTHROW( static_cast<double>(T0) == 11);
}

SECTION("Add/subtract constant") {
T3.fill(1.);
auto T3_plus_1 = T3 + 1.;
for (auto x : T3_plus_1) CHECK(x == 2.);
T3_plus_1 -= 1.;
for (auto x : T3_plus_1) CHECK(x == 1.);
}

SECTION("Generate") {
std::vector<double> data(T3.size());
for (auto& x : data) x = rng();
Expand Down