@@ -48,12 +48,14 @@ namespace detail {
4848
4949// / is_device_tile specialization for UMTensor
5050template <typename T>
51+ requires TiledArray::detail::is_numeric_v<T>
5152struct is_device_tile <
5253 ::TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>>
5354 : public std::true_type {};
5455
5556// / pre-fetch to device
5657template <typename T>
58+ requires TiledArray::detail::is_numeric_v<T>
5759void to_device (const UMTensor<T> &tensor) {
5860 auto stream = device::stream_for (tensor.range ());
5961 TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(tensor,
@@ -62,6 +64,7 @@ void to_device(const UMTensor<T> &tensor) {
6264
6365// / pre-fetch to host
6466template <typename T>
67+ requires TiledArray::detail::is_numeric_v<T>
6568void to_host (const UMTensor<T> &tensor) {
6669 auto stream = device::stream_for (tensor.range ());
6770 TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(tensor,
@@ -83,6 +86,7 @@ void to_host(const UMTensor<T> &tensor) {
8386// / handle ComplexConjugate handling for scaling functions
8487// / follows the logic in device/btas.h
8588template <typename T, typename Scalar, typename Queue>
89+ requires TiledArray::detail::is_numeric_v<T>
8690void apply_scale_factor (T *data, std::size_t size, const Scalar &factor,
8791 Queue &queue) {
8892 if constexpr (TiledArray::detail::is_blas_numeric_v<Scalar> ||
@@ -111,7 +115,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
111115// /
112116
113117template <typename T, typename Scalar>
114- requires TiledArray::detail::is_numeric_v<Scalar>
118+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
115119UMTensor<T> gemm (const UMTensor<T> &left, const UMTensor<T> &right,
116120 Scalar factor,
117121 const TiledArray::math::GemmHelper &gemm_helper) {
@@ -166,7 +170,7 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
166170}
167171
168172template <typename T, typename Scalar>
169- requires TiledArray::detail::is_numeric_v<Scalar>
173+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
170174void gemm (UMTensor<T> &result, const UMTensor<T> &left,
171175 const UMTensor<T> &right, Scalar factor,
172176 const TiledArray::math::GemmHelper &gemm_helper) {
@@ -230,6 +234,7 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
230234// /
231235
232236template <typename T>
237+ requires TiledArray::detail::is_numeric_v<T>
233238UMTensor<T> clone (const UMTensor<T> &arg) {
234239 TA_ASSERT (!arg.empty ());
235240
@@ -252,6 +257,7 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
252257// /
253258
254259template <typename T, typename Index>
260+ requires TiledArray::detail::is_numeric_v<T>
255261UMTensor<T> shift (const UMTensor<T> &arg, const Index &bound_shift) {
256262 TA_ASSERT (!arg.empty ());
257263
@@ -276,6 +282,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
276282}
277283
278284template <typename T, typename Index>
285+ requires TiledArray::detail::is_numeric_v<T>
279286UMTensor<T> &shift_to (UMTensor<T> &arg, const Index &bound_shift) {
280287 const_cast <TiledArray::Range &>(arg.range ()).inplace_shift (bound_shift);
281288 return arg;
@@ -286,6 +293,7 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
286293// /
287294
288295template <typename T>
296+ requires TiledArray::detail::is_numeric_v<T>
289297UMTensor<T> permute (const UMTensor<T> &arg,
290298 const TiledArray::Permutation &perm) {
291299 TA_ASSERT (!arg.empty ());
@@ -308,6 +316,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
308316}
309317
310318template <typename T>
319+ requires TiledArray::detail::is_numeric_v<T>
311320UMTensor<T> permute (const UMTensor<T> &arg,
312321 const TiledArray::BipartitePermutation &perm) {
313322 TA_ASSERT (!arg.empty ());
@@ -320,7 +329,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
320329// /
321330
322331template <typename T, typename Scalar>
323- requires TiledArray::detail::is_numeric_v<Scalar>
332+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
324333UMTensor<T> scale (const UMTensor<T> &arg, const Scalar factor) {
325334 auto &queue = blasqueue_for (arg.range ());
326335 const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -335,7 +344,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
335344}
336345
337346template <typename T, typename Scalar>
338- requires TiledArray::detail::is_numeric_v<Scalar>
347+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
339348UMTensor<T> &scale_to (UMTensor<T> &arg, const Scalar factor) {
340349 auto &queue = blasqueue_for (arg.range ());
341350 const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -352,7 +361,7 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
352361}
353362
354363template <typename T, typename Scalar, typename Perm>
355- requires TiledArray::detail::is_numeric_v<Scalar> &&
364+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
356365 TiledArray::detail::is_permutation_v<Perm>
357366UMTensor<T> scale (const UMTensor<T> &arg, const Scalar factor,
358367 const Perm &perm) {
@@ -365,18 +374,20 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
365374// /
366375
367376template <typename T>
377+ requires TiledArray::detail::is_numeric_v<T>
368378UMTensor<T> neg (const UMTensor<T> &arg) {
369379 return scale (arg, T (-1.0 ));
370380}
371381
372382template <typename T, typename Perm>
373- requires TiledArray::detail::is_permutation_v<Perm>
383+ requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
374384UMTensor<T> neg (const UMTensor<T> &arg, const Perm &perm) {
375385 auto result = neg (arg);
376386 return permute (result, perm);
377387}
378388
379389template <typename T>
390+ requires TiledArray::detail::is_numeric_v<T>
380391UMTensor<T> &neg_to (UMTensor<T> &arg) {
381392 return scale_to (arg, T (-1.0 ));
382393}
@@ -386,6 +397,7 @@ UMTensor<T> &neg_to(UMTensor<T> &arg) {
386397// /
387398
388399template <typename T>
400+ requires TiledArray::detail::is_numeric_v<T>
389401UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
390402 UMTensor<T> result (arg1.range ());
391403
@@ -406,23 +418,23 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
406418}
407419
408420template <typename T, typename Scalar>
409- requires TiledArray::detail::is_numeric_v<Scalar>
421+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
410422UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
411423 const Scalar factor) {
412424 auto result = add (arg1, arg2);
413425 return scale_to (result, factor);
414426}
415427
416428template <typename T, typename Perm>
417- requires TiledArray::detail::is_permutation_v<Perm>
429+ requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
418430UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
419431 const Perm &perm) {
420432 auto result = add (arg1, arg2);
421433 return permute (result, perm);
422434}
423435
424436template <typename T, typename Scalar, typename Perm>
425- requires TiledArray::detail::is_numeric_v<Scalar> &&
437+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
426438 TiledArray::detail::is_permutation_v<Perm>
427439UMTensor<T> add (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
428440 const Scalar factor, const Perm &perm) {
@@ -435,6 +447,7 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
435447// /
436448
437449template <typename T>
450+ requires TiledArray::detail::is_numeric_v<T>
438451UMTensor<T> &add_to (UMTensor<T> &result, const UMTensor<T> &arg) {
439452 auto &queue = blasqueue_for (result.range ());
440453 const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -450,7 +463,7 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
450463}
451464
452465template <typename T, typename Scalar>
453- requires TiledArray::detail::is_numeric_v<Scalar>
466+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
454467UMTensor<T> &add_to (UMTensor<T> &result, const UMTensor<T> &arg,
455468 const Scalar factor) {
456469 add_to (result, arg);
@@ -462,6 +475,7 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg,
462475// /
463476
464477template <typename T>
478+ requires TiledArray::detail::is_numeric_v<T>
465479UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
466480 UMTensor<T> result (arg1.range ());
467481
@@ -482,23 +496,23 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
482496}
483497
484498template <typename T, typename Scalar>
485- requires TiledArray::detail::is_numeric_v<Scalar>
499+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
486500UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
487501 const Scalar factor) {
488502 auto result = subt (arg1, arg2);
489503 return scale_to (result, factor);
490504}
491505
492506template <typename T, typename Perm>
493- requires TiledArray::detail::is_permutation_v<Perm>
507+ requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
494508UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
495509 const Perm &perm) {
496510 auto result = subt (arg1, arg2);
497511 return permute (result, perm);
498512}
499513
500514template <typename T, typename Scalar, typename Perm>
501- requires TiledArray::detail::is_numeric_v<Scalar> &&
515+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
502516 TiledArray::detail::is_permutation_v<Perm>
503517UMTensor<T> subt (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
504518 const Scalar factor, const Perm &perm) {
@@ -511,6 +525,7 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
511525// /
512526
513527template <typename T>
528+ requires TiledArray::detail::is_numeric_v<T>
514529UMTensor<T> &subt_to (UMTensor<T> &result, const UMTensor<T> &arg) {
515530 auto &queue = blasqueue_for (result.range ());
516531 const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -526,7 +541,7 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
526541}
527542
528543template <typename T, typename Scalar>
529- requires TiledArray::detail::is_numeric_v<Scalar>
544+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
530545UMTensor<T> &subt_to (UMTensor<T> &result, const UMTensor<T> &arg,
531546 const Scalar factor) {
532547 subt_to (result, arg);
@@ -538,6 +553,7 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
538553// /
539554
540555template <typename T>
556+ requires TiledArray::detail::is_numeric_v<T>
541557UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
542558 TA_ASSERT (arg1.size () == arg2.size ());
543559
@@ -557,23 +573,23 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
557573}
558574
559575template <typename T, typename Scalar>
560- requires TiledArray::detail::is_numeric_v<Scalar>
576+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T>
561577UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
562578 const Scalar factor) {
563579 auto result = mult (arg1, arg2);
564580 return scale_to (result, factor);
565581}
566582
567583template <typename T, typename Perm>
568- requires TiledArray::detail::is_permutation_v<Perm>
584+ requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
569585UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
570586 const Perm &perm) {
571587 auto result = mult (arg1, arg2);
572588 return permute (result, perm);
573589}
574590
575591template <typename T, typename Scalar, typename Perm>
576- requires TiledArray::detail::is_numeric_v<Scalar> &&
592+ requires TiledArray::detail::is_numeric_v<Scalar> && TiledArray::detail::is_numeric_v<T> &&
577593 TiledArray::detail::is_permutation_v<Perm>
578594UMTensor<T> mult (const UMTensor<T> &arg1, const UMTensor<T> &arg2,
579595 const Scalar factor, const Perm &perm) {
@@ -586,6 +602,7 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
586602// /
587603
588604template <typename T>
605+ requires TiledArray::detail::is_numeric_v<T>
589606UMTensor<T> &mult_to (UMTensor<T> &result, const UMTensor<T> &arg) {
590607 auto stream = device::stream_for (result.range ());
591608 TA_ASSERT (result.size () == arg.size ());
@@ -614,6 +631,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
614631// /
615632
616633template <typename T>
634+ requires TiledArray::detail::is_numeric_v<T>
617635T dot (const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
618636 auto &queue = blasqueue_for (arg1.range ());
619637 const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -634,6 +652,7 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
634652// /
635653
636654template <typename T>
655+ requires TiledArray::detail::is_numeric_v<T>
637656T squared_norm (const UMTensor<T> &arg) {
638657 auto &queue = blasqueue_for (arg.range ());
639658 const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -649,11 +668,13 @@ T squared_norm(const UMTensor<T> &arg) {
649668}
650669
651670template <typename T>
671+ requires TiledArray::detail::is_numeric_v<T>
652672T norm (const UMTensor<T> &arg) {
653673 return std::sqrt (squared_norm (arg));
654674}
655675
656676template <typename T>
677+ requires TiledArray::detail::is_numeric_v<T>
657678T sum (const UMTensor<T> &arg) {
658679 detail::to_device (arg);
659680 auto stream = device::stream_for (arg.range ());
@@ -664,6 +685,7 @@ T sum(const UMTensor<T> &arg) {
664685}
665686
666687template <typename T>
688+ requires TiledArray::detail::is_numeric_v<T>
667689T product (const UMTensor<T> &arg) {
668690 detail::to_device (arg);
669691 auto stream = device::stream_for (arg.range ());
@@ -674,6 +696,7 @@ T product(const UMTensor<T> &arg) {
674696}
675697
676698template <typename T>
699+ requires TiledArray::detail::is_numeric_v<T>
677700T max (const UMTensor<T> &arg) {
678701 detail::to_device (arg);
679702 auto stream = device::stream_for (arg.range ());
@@ -684,6 +707,7 @@ T max(const UMTensor<T> &arg) {
684707}
685708
686709template <typename T>
710+ requires TiledArray::detail::is_numeric_v<T>
687711T min (const UMTensor<T> &arg) {
688712 detail::to_device (arg);
689713 auto stream = device::stream_for (arg.range ());
@@ -694,6 +718,7 @@ T min(const UMTensor<T> &arg) {
694718}
695719
696720template <typename T>
721+ requires TiledArray::detail::is_numeric_v<T>
697722T abs_max (const UMTensor<T> &arg) {
698723 detail::to_device (arg);
699724 auto stream = device::stream_for (arg.range ());
@@ -704,6 +729,7 @@ T abs_max(const UMTensor<T> &arg) {
704729}
705730
706731template <typename T>
732+ requires TiledArray::detail::is_numeric_v<T>
707733T abs_min (const UMTensor<T> &arg) {
708734 detail::to_device (arg);
709735 auto stream = device::stream_for (arg.range ());
@@ -721,6 +747,7 @@ namespace madness {
721747namespace archive {
722748
723749template <typename Archive, typename T>
750+ requires TiledArray::detail::is_numeric_v<T>
724751struct ArchiveStoreImpl <Archive, TiledArray::UMTensor<T>> {
725752 static inline void store (const Archive &ar,
726753 const TiledArray::UMTensor<T> &t) {
@@ -736,6 +763,7 @@ struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
736763};
737764
738765template <typename Archive, typename T>
766+ requires TiledArray::detail::is_numeric_v<T>
739767struct ArchiveLoadImpl <Archive, TiledArray::UMTensor<T>> {
740768 static inline void load (const Archive &ar, TiledArray::UMTensor<T> &t) {
741769 TiledArray::Range range{};
0 commit comments