Skip to content

Commit 03774e5

Browse files
committed
Moved reorder logic to assignment
1 parent ae52796 commit 03774e5

File tree

1 file changed

+54
-3
lines changed

1 file changed

+54
-3
lines changed

include/xtensor/xassign.hpp

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,16 +439,67 @@ namespace xt
439439
using requested_value_type = detail::conditional_promote_to_complex_t<e1_value_type, e2_requested_value_type>;
440440
};
441441

442+
/**********************************
443+
* Expression Order Optimizations *
444+
**********************************/
445+
446+
class optimize_expression
447+
{
448+
private:
449+
template <class E1, class E2>
450+
struct equal_rank
451+
{
452+
static constexpr bool value = get_rank<E1>::value == get_rank<E2>::value;
453+
};
454+
455+
template <class E1, class... E>
456+
struct all_equal_rank
457+
{
458+
static constexpr bool value = xtl::conjunction<equal_rank<E1, E>...>::value
459+
&& (get_rank<E1>::value != SIZE_MAX);
460+
};
461+
462+
template <class F, class... CT, class... S, size_t... I, size_t... J>
463+
inline auto impl_reorder_function(const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
464+
{
465+
return make_lambda_xfunction(F(), view(std::get<I>(e.arguments()), std::get<J>(slices)...)...);
466+
}
467+
468+
public:
469+
//when we have a view of a function where the closures of the functions are of equal rank (i.e no broadcasting)
470+
//we can flip the order of the function and the view such that we have a function of views of containers which
471+
//can be linearly assigned unlike the inverse.
472+
template <class F, class... CT, class... S, class = std::enable_if_t<all_equal_rank<std::decay_t<CT>...>::value>>
473+
inline auto reorder(const xview<xfunction<F, CT...>, S...>& e)
474+
{
475+
return impl_reorder_function(
476+
e.expression(),
477+
e.slices(),
478+
std::make_index_sequence<sizeof...(CT)>(),
479+
std::make_index_sequence<sizeof...(S)>()
480+
);
481+
}
482+
483+
//base case no applicable optimization
484+
template<class E>
485+
inline auto& reorder(E&& e)
486+
{
487+
return std::forward<E>(e);
488+
}
489+
};
490+
442491
template <class E1, class E2>
443492
inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
444493
xexpression<E1>& e1,
445494
const xexpression<E2>& e2,
446495
bool trivial
447496
)
448497
{
449-
E1& de1 = e1.derived_cast();
450-
const E2& de2 = e2.derived_cast();
451-
using traits = xassign_traits<E1, E2>;
498+
auto& de1 = e1.derived_cast();
499+
const auto& de2 = optimize_expression().reorder(e2.derived_cast());
500+
using dst_type = typename std::decay_t<decltype(de1)>;
501+
using src_type = typename std::decay_t<decltype(de2)>;
502+
using traits = xassign_traits<dst_type, src_type>;
452503

453504
bool linear_assign = traits::linear_assign(de1, de2, trivial);
454505
constexpr bool simd_assign = traits::simd_assign();

0 commit comments

Comments
 (0)