@@ -439,16 +439,67 @@ namespace xt
439
439
using requested_value_type = detail::conditional_promote_to_complex_t <e1_value_type, e2_requested_value_type>;
440
440
};
441
441
442
+ /* *********************************
443
+ * Expression Order Optimizations *
444
+ **********************************/
445
+
446
+ class optimize_expression
447
+ {
448
+ private:
449
+ template <class E1 , class E2 >
450
+ struct equal_rank
451
+ {
452
+ static constexpr bool value = get_rank<E1 >::value == get_rank<E2 >::value;
453
+ };
454
+
455
+ template <class E1 , class ... E>
456
+ struct all_equal_rank
457
+ {
458
+ static constexpr bool value = xtl::conjunction<equal_rank<E1 , E>...>::value
459
+ && (get_rank<E1 >::value != SIZE_MAX);
460
+ };
461
+
462
+ template <class F , class ... CT, class ... S, size_t ... I, size_t ... J>
463
+ inline auto impl_reorder_function (const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
464
+ {
465
+ return make_lambda_xfunction (F (), view (std::get<I>(e.arguments ()), std::get<J>(slices)...)...);
466
+ }
467
+
468
+ public:
469
+ // when we have a view of a function where the closures of the functions are of equal rank (i.e no broadcasting)
470
+ // we can flip the order of the function and the view such that we have a function of views of containers which
471
+ // can be linearly assigned unlike the inverse.
472
+ template <class F , class ... CT, class ... S, class = std::enable_if_t <all_equal_rank<std::decay_t <CT>...>::value>>
473
+ inline auto reorder (const xview<xfunction<F, CT...>, S...>& e)
474
+ {
475
+ return impl_reorder_function (
476
+ e.expression (),
477
+ e.slices (),
478
+ std::make_index_sequence<sizeof ...(CT)>(),
479
+ std::make_index_sequence<sizeof ...(S)>()
480
+ );
481
+ }
482
+
483
+ // base case no applicable optimization
484
+ template <class E >
485
+ inline auto & reorder (E&& e)
486
+ {
487
+ return std::forward<E>(e);
488
+ }
489
+ };
490
+
442
491
template <class E1 , class E2 >
443
492
inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
444
493
xexpression<E1 >& e1 ,
445
494
const xexpression<E2 >& e2 ,
446
495
bool trivial
447
496
)
448
497
{
449
- E1 & de1 = e1 .derived_cast ();
450
- const E2 & de2 = e2 .derived_cast ();
451
- using traits = xassign_traits<E1 , E2 >;
498
+ auto & de1 = e1 .derived_cast ();
499
+ const auto & de2 = optimize_expression ().reorder (e2 .derived_cast ());
500
+ using dst_type = typename std::decay_t <decltype (de1)>;
501
+ using src_type = typename std::decay_t <decltype (de2)>;
502
+ using traits = xassign_traits<dst_type, src_type>;
452
503
453
504
bool linear_assign = traits::linear_assign (de1, de2, trivial);
454
505
constexpr bool simd_assign = traits::simd_assign ();
0 commit comments