Skip to content

Commit ee2ff1f

Browse files
committed
fixed btas and kronecker delta Tile API
1 parent b166a48 commit ee2ff1f

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

src/TiledArray/external/btas.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,13 @@ inline btas::Tensor<T, Range, Storage>& shift_to(
255255
return arg;
256256
}
257257

258+
template <typename T, typename Range, typename Storage, typename Index>
259+
inline btas::Tensor<T, Range, Storage>&& shift_to(
260+
btas::Tensor<T, Range, Storage>&& arg, const Index& range_shift) {
261+
const_cast<Range&>(arg.range()).inplace_shift(range_shift);
262+
return std::move(arg);
263+
}
264+
258265
/// result[i] = arg1[i] + arg2[i]
259266
template <typename T, typename Range, typename Storage>
260267
inline btas::Tensor<T, Range, Storage> add(
@@ -388,6 +395,16 @@ inline btas::Tensor<T, Range, Storage>& subt_to(
388395
return result;
389396
}
390397

398+
template <typename T, typename Range, typename Storage>
399+
inline btas::Tensor<T, Range, Storage>&& subt_to(
400+
btas::Tensor<T, Range, Storage>&& result,
401+
const btas::Tensor<T, Range, Storage>& arg) {
402+
auto result_view = make_ti(result);
403+
auto arg_view = make_ti(arg);
404+
result_view.subt_to(arg_view);
405+
return std::move(result);
406+
}
407+
391408
template <typename T, typename Range, typename Storage, typename Scalar,
392409
typename std::enable_if<
393410
TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
@@ -400,6 +417,18 @@ inline btas::Tensor<T, Range, Storage>& subt_to(
400417
return result;
401418
}
402419

420+
template <typename T, typename Range, typename Storage, typename Scalar,
421+
typename std::enable_if<
422+
TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
423+
inline btas::Tensor<T, Range, Storage>&& subt_to(
424+
btas::Tensor<T, Range, Storage>&& result,
425+
const btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
426+
auto result_view = make_ti(result);
427+
auto arg_view = make_ti(arg);
428+
result_view.subt_to(arg_view, factor);
429+
return std::move(result);
430+
}
431+
403432
/// result[i] = arg1[i] * arg2[i]
404433
template <typename T, typename Range, typename Storage>
405434
inline btas::Tensor<T, Range, Storage> mult(
@@ -460,6 +489,16 @@ inline btas::Tensor<T, Range, Storage>& mult_to(
460489
return result;
461490
}
462491

492+
template <typename T, typename Range, typename Storage>
493+
inline btas::Tensor<T, Range, Storage>&& mult_to(
494+
btas::Tensor<T, Range, Storage>&& result,
495+
const btas::Tensor<T, Range, Storage>& arg) {
496+
auto result_view = make_ti(result);
497+
auto arg_view = make_ti(arg);
498+
result_view.mult_to(arg_view);
499+
return std::move(result);
500+
}
501+
463502
/// result[i] *= arg[i] * factor
464503
template <typename T, typename Range, typename Storage, typename Scalar,
465504
typename std::enable_if<
@@ -473,6 +512,18 @@ inline btas::Tensor<T, Range, Storage>& mult_to(
473512
return result;
474513
}
475514

515+
template <typename T, typename Range, typename Storage, typename Scalar,
516+
typename std::enable_if<
517+
TiledArray::detail::is_numeric_v<Scalar>>::type* = nullptr>
518+
inline btas::Tensor<T, Range, Storage>&& mult_to(
519+
btas::Tensor<T, Range, Storage>&& result,
520+
const btas::Tensor<T, Range, Storage>& arg, const Scalar factor) {
521+
auto result_view = make_ti(result);
522+
auto arg_view = make_ti(arg);
523+
result_view.mult_to(arg_view, factor);
524+
return std::move(result);
525+
}
526+
476527
// Generic element-wise binary operations
477528
// ---------------------------------------------
478529

@@ -540,6 +591,14 @@ inline btas::Tensor<T, Range, Storage>& neg_to(
540591
return result;
541592
}
542593

594+
template <typename T, typename Range, typename Storage>
595+
inline btas::Tensor<T, Range, Storage>&& neg_to(
596+
btas::Tensor<T, Range, Storage>&& result) {
597+
auto result_view = make_ti(result);
598+
result_view.neg_to();
599+
return std::move(result);
600+
}
601+
543602
template <typename T, typename Range, typename Storage>
544603
inline btas::Tensor<T, Range, Storage> neg(
545604
const btas::Tensor<T, Range, Storage>& arg) {
@@ -600,6 +659,14 @@ inline btas::Tensor<T, Range, Storage>& conj_to(
600659
return arg;
601660
}
602661

662+
template <typename T, typename Range, typename Storage>
663+
inline btas::Tensor<T, Range, Storage>&& conj_to(
664+
btas::Tensor<T, Range, Storage>&& arg) {
665+
auto arg_view = make_ti(arg);
666+
arg_view.conj_to();
667+
return std::move(arg);
668+
}
669+
603670
template <typename T, typename Range, typename Storage, typename Scalar,
604671
std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
605672
inline btas::Tensor<T, Range, Storage>& conj_to(
@@ -609,6 +676,15 @@ inline btas::Tensor<T, Range, Storage>& conj_to(
609676
return arg;
610677
}
611678

679+
template <typename T, typename Range, typename Storage, typename Scalar,
680+
std::enable_if_t<TiledArray::detail::is_numeric_v<Scalar>>* = nullptr>
681+
inline btas::Tensor<T, Range, Storage>&& conj_to(
682+
btas::Tensor<T, Range, Storage>&& arg, const Scalar factor) {
683+
auto arg_view = make_ti(arg);
684+
arg_view.conj_to(factor);
685+
return std::move(arg);
686+
}
687+
612688
// Generic element-wise unary operations
613689
// ---------------------------------------------
614690

src/TiledArray/special/kronecker_delta.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ Tensor<T> mult(const KroneckerDeltaTile& arg1, const Tensor<T>& arg2,
168168
template <typename T>
169169
Tensor<T>& mult_to(Tensor<T>& result, const KroneckerDeltaTile& arg1) {
170170
abort();
171-
return result;
171+
}
172+
173+
template <typename T>
174+
Tensor<T>&& mult_to(Tensor<T>&& result, const KroneckerDeltaTile& arg1) {
175+
abort();
172176
}
173177

174178
// dense_result[i] = binary(dense_arg1[i], sparse_arg2[i], op)

0 commit comments

Comments
 (0)