diff --git a/btas/tensor.h b/btas/tensor.h index 106098da..85eb50e8 100644 --- a/btas/tensor.h +++ b/btas/tensor.h @@ -585,6 +585,24 @@ namespace btas { return y; /* automatically called move semantics */ } + /// adds a number to every element + Tensor operator+(const value_type& x) const { + Tensor y = this->clone(); + y += x; + return y; /* automatically called move semantics */ + } + + /// adds a number to every element + Tensor& operator+=(const value_type& x) { + using std::begin; + using std::cbegin; + using std::cend; + std::transform(cbegin(storage_), cend(storage_), begin(storage_), [x](const auto& v) { + return v+x; + }); + return *this; + } + /// subtraction assignment Tensor& operator-=(const Tensor& x) { using std::begin; @@ -603,6 +621,24 @@ namespace btas { return y; /* automatically called move semantics */ } + /// subtracts a number from every element + Tensor operator-(const value_type& x) const { + Tensor y = this->clone(); + y -= x; + return y; /* automatically called move semantics */ + } + + /// subtracts a number from every element + Tensor& operator-=(const value_type& x) { + using std::begin; + using std::cbegin; + using std::cend; + std::transform(cbegin(storage_), cend(storage_), begin(storage_), [x](const auto& v) { + return v-x; + }); + return *this; + } + /// \return bare const pointer to the first element of data_ /// this enables to call BLAS functions const_pointer data() const { diff --git a/unittest/tensor_test.cc b/unittest/tensor_test.cc index 95bae874..d6b6e46a 100644 --- a/unittest/tensor_test.cc +++ b/unittest/tensor_test.cc @@ -219,6 +219,14 @@ TEST_CASE("Tensor Operations") { CHECK_NOTHROW( static_cast(T0) == 11); } + SECTION("Add/subtract constant") { + T3.fill(1.); + auto T3_plus_1 = T3 + 1.; + for (auto x : T3_plus_1) CHECK(x == 2.); + T3_plus_1 -= 1.; + for (auto x : T3_plus_1) CHECK(x == 1.); + } + SECTION("Generate") { std::vector data(T3.size()); for (auto& x : data) x = rng();