Skip to content

Commit a8f1a6a

Browse files
authored
Merge pull request #2238 from tdegeus/argsort
argsort: catching zeros stride leading axis (bugfix)
2 parents 1b36ad6 + fcf6ade commit a8f1a6a

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

include/xtensor/xsort.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,15 @@ namespace xt
279279
{
280280
n_iters = std::accumulate(data.shape().begin(), data.shape().end() - 1,
281281
std::size_t(1), std::multiplies<>());
282-
data_secondary_stride = data.strides()[data.dimension() - 2];
283-
inds_secondary_stride = inds.strides()[inds.dimension() - 2];
282+
data_secondary_stride = data.shape(data.dimension() - 1);
283+
inds_secondary_stride = inds.shape(inds.dimension() - 1);
284284
}
285285
else
286286
{
287287
n_iters = std::accumulate(data.shape().begin() + 1, data.shape().end(),
288288
std::size_t(1), std::multiplies<>());
289-
data_secondary_stride = data.strides()[1];
290-
inds_secondary_stride = inds.strides()[1];
289+
data_secondary_stride = data.shape(0);
290+
inds_secondary_stride = inds.shape(0);
291291
}
292292

293293
auto ptr = data.data();

test/test_xsort.cpp

+43-1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,48 @@ namespace xt
7171
}
7272
}
7373

74+
TEST(xsort, argsort_zero_stride)
75+
{
76+
{
77+
xt::xtensor<double, 2> A = {{1.4, 1.3, 1.2, 1.1}};
78+
xt::xtensor<size_t, 2> bsort = {{0, 0, 0, 0}};
79+
xt::xtensor<size_t, 2> fsort = {{3, 2, 1, 0}};
80+
EXPECT_EQ(bsort, xt::argsort(A, 0));
81+
EXPECT_EQ(fsort, xt::argsort(A, 1));
82+
EXPECT_EQ(fsort, xt::argsort(A));
83+
}
84+
{
85+
xt::xtensor<double, 3> A = {{{1.4, 1.3, 1.2, 1.1}}};
86+
xt::xtensor<size_t, 3> bsort = {{{0, 0, 0, 0}}};
87+
xt::xtensor<size_t, 3> fsort = {{{3, 2, 1, 0}}};
88+
EXPECT_EQ(bsort, xt::argsort(A, 0));
89+
EXPECT_EQ(bsort, xt::argsort(A, 1));
90+
EXPECT_EQ(fsort, xt::argsort(A, 2));
91+
EXPECT_EQ(fsort, xt::argsort(A));
92+
}
93+
}
94+
95+
TEST(xsort, argsort_zero_stride_column_major)
96+
{
97+
{
98+
xt::xtensor<double, 2, xt::layout_type::column_major> A = {{1.4, 1.3, 1.2, 1.1}};
99+
xt::xtensor<size_t, 2, xt::layout_type::column_major> bsort = {{0, 0, 0, 0}};
100+
xt::xtensor<size_t, 2, xt::layout_type::column_major> fsort = {{3, 2, 1, 0}};
101+
EXPECT_EQ(bsort, xt::argsort(A, 0));
102+
EXPECT_EQ(fsort, xt::argsort(A, 1));
103+
EXPECT_EQ(fsort, xt::argsort(A));
104+
}
105+
{
106+
xt::xtensor<double, 3, xt::layout_type::column_major> A = {{{1.4, 1.3, 1.2, 1.1}}};
107+
xt::xtensor<size_t, 3, xt::layout_type::column_major> bsort = {{{0, 0, 0, 0}}};
108+
xt::xtensor<size_t, 3, xt::layout_type::column_major> fsort = {{{3, 2, 1, 0}}};
109+
EXPECT_EQ(bsort, xt::argsort(A, 0));
110+
EXPECT_EQ(bsort, xt::argsort(A, 1));
111+
EXPECT_EQ(fsort, xt::argsort(A, 2));
112+
EXPECT_EQ(fsort, xt::argsort(A));
113+
}
114+
}
115+
74116
TEST(xsort, flatten_argsort)
75117
{
76118
{
@@ -256,7 +298,7 @@ namespace xt
256298
auto d = unique(c);
257299
EXPECT_EQ(d, bx);
258300

259-
auto e = xt::unique(xt::where(xt::greater(b,2), 1, 0));
301+
auto e = xt::unique(xt::where(xt::greater(b,2), 1, 0));
260302
xarray<double> ex = {0, 1};
261303
EXPECT_EQ(e, ex);
262304
}

0 commit comments

Comments
 (0)