Skip to content

Commit

Permalink
Merge pull request #2428 from cambridgeconsultants:cc_up_masked_ops
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722985349
  • Loading branch information
copybara-github committed Feb 4, 2025
2 parents f129551 + af4183d commit 326dba3
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 0 deletions.
21 changes: 21 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,9 @@ types, and on SVE/RVV.

* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.

* <code>V **MaskedOr**(M m, V a, V b)</code>: returns `a[i] | b[i]`,
or zero if `m[i]` is false.

The following three-argument functions may be more efficient than assembling
them from 2-argument functions:

Expand Down Expand Up @@ -2491,6 +2494,24 @@ more efficient on some targets.
* <code>T **ReduceMin**(D, V v)</code>: returns the minimum of all lanes.
* <code>T **ReduceMax**(D, V v)</code>: returns the maximum of all lanes.
### Masked reductions
**Note**: Horizontal operations (across lanes of the same vector) such as
reductions are slower than normal SIMD operations and are typically used outside
critical loops.
All ops in this section ignore lanes where `m[i]` is false. These are
equivalent to, and potentially more efficient than, `GetLane(SumOfLanes(d,
IfThenElseZero(m, v)))` (and the corresponding min/max expressions). The result
is implementation-defined when all mask elements are false.
* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes
where `m[i]` is `true`.
* <code>T **MaskedReduceMin**(D, M m, V v)</code>: returns the minimum of all
lanes where `m[i]` is `true`.
* <code>T **MaskedReduceMax**(D, M m, V v)</code>: returns the maximum of all
lanes where `m[i]` is `true`.
### Crypto
Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
Expand Down
47 changes: 47 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
}
// Two-operand op governed by a caller-supplied predicate, using the `_z`
// (zeroing) form of the intrinsic: lanes whose mask bit is false yield zero.
#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP)         \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(svbool_t mask, HWY_SVE_V(BASE, BITS) lhs,                    \
           HWY_SVE_V(BASE, BITS) rhs) {                                 \
    return sv##OP##_##CHAR##BITS##_z(mask, lhs, rhs);                   \
  }

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
Expand Down Expand Up @@ -763,6 +769,9 @@ HWY_API V Or(const V a, const V b) {
return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}

// ------------------------------ MaskedOr
// Generates MaskedOr(m, a, b) for every type in the HWY_SVE_FOREACH_UI list
// (presumably the unsigned and signed integer types - confirm against the
// macro definition). Each overload calls the zeroing `svorr_*_z` intrinsic, so
// lanes where `m` is false are zero.
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr)

// ------------------------------ Xor

namespace detail {
Expand Down Expand Up @@ -1678,6 +1687,7 @@ namespace detail {
return sv##OP##_##CHAR##BITS(pg, v); \
}

// Predicated sum reduction built on the SVE `addv` intrinsics. The types in
// the FOREACH_UI list are generated via HWY_SVE_REDUCE_ADD, those in the
// FOREACH_F list via HWY_SVE_REDUCE. MaskedReduceSum forwards to this helper.
// TODO: Remove SumOfLanesM in favor of using MaskedReduceSum
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv)
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv)

Expand Down Expand Up @@ -1725,6 +1735,25 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
return detail::MaxOfLanesM(detail::MakeMask(d), v);
}

// ------------------------------ MaskedReduceSum/Min/Max

// Flips the defined-ness of HWY_NATIVE_MASKED_REDUCE_SCALAR. The generic
// fallback in generic_ops-inl.h is guarded by a comparison of this macro's
// defined-ness with that of HWY_TARGET_TOGGLE, so flipping it here is
// presumably what suppresses the fallback for this target (NOTE(review):
// confirm against the HWY_TARGET_TOGGLE convention).
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

// Returns the sum of the lanes of `v` selected by `m`, by passing the mask
// straight to detail::SumOfLanesM. `d` only serves to deduce the lane type.
template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) {
  const TFromD<D> sum = detail::SumOfLanesM(m, v);
  return sum;
}
// Returns the minimum of the lanes of `v` selected by `m`, by passing the mask
// straight to detail::MinOfLanesM. `d` only serves to deduce the lane type.
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) {
  const TFromD<D> smallest = detail::MinOfLanesM(m, v);
  return smallest;
}
// Returns the maximum of the lanes of `v` selected by `m`, by passing the mask
// straight to detail::MaxOfLanesM. `d` only serves to deduce the lane type.
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) {
  const TFromD<D> largest = detail::MaxOfLanesM(m, v);
  return largest;
}

// ------------------------------ SumOfLanes

template <class D, HWY_IF_LANES_GT_D(D, 1)>
Expand Down Expand Up @@ -5056,6 +5085,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
return IfThenElse(IsNegative(v), yes, no);
}
// ------------------------------ IfNegativeThenNegOrUndefIfZero

// Flips the defined-ness of HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG, following
// the same #ifdef/#undef/#else/#define toggle pattern used for other
// HWY_NATIVE_* macros in this header. Presumably this tells the generic code
// that this target supplies its own integer implementation (NOTE(review):
// confirm against the guard in generic_ops-inl.h).
#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#else
#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
#endif

// Helper that emits one overload per type. Note that `mask` is a vector, not
// a predicate: its negative lanes select where OP is applied to `v`. The `_m`
// (merging) intrinsic receives `v` as its first argument as well, so
// unselected lanes are presumably returned unchanged (see ACLE `_m` forms).
#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) {       \
    const svbool_t is_negative = IsNegative(mask);                      \
    return sv##OP##_##CHAR##BITS##_m(v, is_negative, v);                \
  }

// Instantiate with the `neg` intrinsic for every type in HWY_SVE_FOREACH_IF.
HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg)

// The helper is not needed beyond this point.
#undef HWY_SVE_NEG_IF

// ------------------------------ AverageRound (ShiftRight)

Expand Down Expand Up @@ -6610,6 +6656,7 @@ HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount,
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMVV_Z
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGPV
Expand Down
26 changes: 26 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,28 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
}
#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8

// ------------------------------ MaskedReduceSum/Min/Max

// Generic fallbacks, compiled only when the defined-ness of
// HWY_NATIVE_MASKED_REDUCE_SCALAR equals that of HWY_TARGET_TOGGLE. The macro
// is then flipped, mirroring the toggle performed by targets (e.g. SVE) that
// provide their own implementation.
#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
#else
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
#endif

// Returns the sum of the lanes of `v` for which `m` is true. Masked-off lanes
// are replaced by zero before the full-vector ReduceSum.
template <class D, class M>
HWY_API TFromD<D> MaskedReduceSum(D d, M m, VFromD<D> v) {
  const VFromD<D> selected = IfThenElseZero(m, v);
  return ReduceSum(d, selected);
}
// Returns the minimum of the lanes of `v` for which `m` is true. Masked-off
// lanes are replaced by hwy::PositiveInfOrHighestValue<T>() before the
// full-vector ReduceMin; if every lane is masked off, that filler value is
// what this fallback returns.
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) {
  using T = TFromD<D>;
  const VFromD<D> filler = Set(d, hwy::PositiveInfOrHighestValue<T>());
  const VFromD<D> selected = IfThenElse(m, v, filler);
  return ReduceMin(d, selected);
}
// Returns the maximum of the lanes of `v` for which `m` is true. Masked-off
// lanes are replaced by hwy::NegativeInfOrLowestValue<T>() before the
// full-vector ReduceMax; if every lane is masked off, that filler value is
// what this fallback returns.
template <class D, class M>
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) {
  using T = TFromD<D>;
  const VFromD<D> filler = Set(d, hwy::NegativeInfOrLowestValue<T>());
  const VFromD<D> selected = IfThenElse(m, v, filler);
  return ReduceMax(d, selected);
}

#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR

// ------------------------------ IsEitherNaN
#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_IS_EITHER_NAN
Expand Down Expand Up @@ -7568,6 +7590,10 @@ HWY_API V BitShuffle(V v, VI idx) {

#endif // HWY_NATIVE_BITSHUFFLE

// ------------------------------ MaskedOr

// Returns `a[i] | b[i]` in lanes where `m[i]` is true and zero elsewhere.
template <class V, class M>
HWY_API V MaskedOr(M m, V a, V b) {
  const V either = Or(a, b);
  return IfThenElseZero(m, either);
}
// ------------------------------ AllBits1/AllBits0
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_ALLONES
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/rvv-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4755,6 +4755,8 @@ HWY_API T ReduceMax(D d, const VFromD<D> v) {

#undef HWY_RVV_REDUCE

// TODO: add MaskedReduceSum/Min/Max

// ------------------------------ SumOfLanes

template <class D, HWY_IF_LANES_GT_D(D, 1)>
Expand Down
23 changes: 23 additions & 0 deletions hwy/tests/logical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,28 @@ HWY_NOINLINE void TestAllTestBit() {
ForIntegerTypes(ForPartialVectors<TestTestBit>());
}

// Checks MaskedOr against Or for an all-true mask, and against
// IfThenElse(mask, Or, Zero) for a mask covering the first five lanes.
struct TestMaskedOr {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const auto lhs = Iota(d, 1);
    const auto rhs = Iota(d, 2);

    // With every lane enabled, MaskedOr must match plain Or.
    const MFromD<D> every_lane = MaskTrue(d);
    HWY_ASSERT_VEC_EQ(d, Or(rhs, lhs), MaskedOr(every_lane, lhs, rhs));

    // With only the first five lanes enabled, the remaining lanes are zero.
    const MFromD<D> leading = FirstN(d, 5);
    const Vec<D> zero = Zero(d);
    const Vec<D> expected = IfThenElse(leading, Or(rhs, lhs), zero);
    HWY_ASSERT_VEC_EQ(d, expected, MaskedOr(leading, lhs, rhs));
  }
};

// Runs TestMaskedOr for all lane types on partial vectors.
HWY_NOINLINE void TestAllMaskedLogical() {
  using PerVector = ForPartialVectors<TestMaskedOr>;
  ForAllTypes(PerVector());
}

struct TestAllBits {
template <class T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
Expand Down Expand Up @@ -185,6 +207,7 @@ HWY_BEFORE_TEST(HwyLogicalTest);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllMaskedLogical);
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllBits);

HWY_AFTER_TEST();
Expand Down
126 changes: 126 additions & 0 deletions hwy/tests/reduction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,128 @@ HWY_NOINLINE void TestAllSumsOf8() {
ForGEVectors<64, TestSumsOf8>()(uint8_t());
}

// Compares MaskedReduceSum on Iota(d, 2) with a scalar sum over the same
// randomly enabled lanes.
struct TestMaskedReduceSum {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    RandomState rng;

    using TI = MakeSigned<T>;
    const Rebind<TI, D> di;
    const Vec<D> values = Iota(d, 2);  // lane i holds i + 2

    const size_t N = Lanes(d);
    // One 0/1 flag per lane; converted into a mask for each repetition.
    auto lane_flags = AllocateAligned<TI>(N);
    HWY_ASSERT(lane_flags);

    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
      T expected = 0;
      for (size_t i = 0; i < N; ++i) {
        const bool enabled = (Random32(&rng) & 1024) != 0;
        lane_flags[i] = enabled ? TI(1) : TI(0);
        if (enabled) {
          expected += ConvertScalarTo<T>(i + 2);
        }
      }

      const Mask<D> mask =
          RebindMask(d, Gt(Load(di, lane_flags.get()), Zero(di)));

      // If all elements are disabled the result is implementation defined,
      // so only verify when at least one lane is enabled.
      if (!AllFalse(d, mask)) {
        HWY_ASSERT_EQ(expected, MaskedReduceSum(d, mask, values));
      }
    }
  }
};

// Runs TestMaskedReduceSum for all lane types on partial vectors.
HWY_NOINLINE void TestAllMaskedReduceSum() {
  using PerVector = ForPartialVectors<TestMaskedReduceSum>;
  ForAllTypes(PerVector());
}

// Compares MaskedReduceMin on Iota(d, 2) with a scalar minimum over the same
// randomly enabled lanes.
struct TestMaskedReduceMin {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    RandomState rng;

    using TI = MakeSigned<T>;
    const Rebind<TI, D> di;
    const Vec<D> values = Iota(d, 2);  // lane i holds i + 2

    const size_t N = Lanes(d);
    // One 0/1 flag per lane; converted into a mask for each repetition.
    auto lane_flags = AllocateAligned<TI>(N);
    HWY_ASSERT(lane_flags);

    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
      // Initial value intended to be larger than any value in the vector.
      T expected = ConvertScalarTo<T>(N + 3);
      for (size_t i = 0; i < N; ++i) {
        const bool enabled = (Random32(&rng) & 1024) != 0;
        lane_flags[i] = enabled ? TI(1) : TI(0);
        if (enabled) {
          const T lane = ConvertScalarTo<T>(i + 2);
          if (expected > lane) {
            expected = lane;
          }
        }
      }

      const Mask<D> mask =
          RebindMask(d, Gt(Load(di, lane_flags.get()), Zero(di)));

      // If all elements are disabled the result is implementation defined,
      // so only verify when at least one lane is enabled.
      if (!AllFalse(d, mask)) {
        HWY_ASSERT_EQ(expected, MaskedReduceMin(d, mask, values));
      }
    }
  }
};

// Runs TestMaskedReduceMin for all lane types on partial vectors.
HWY_NOINLINE void TestAllMaskedReduceMin() {
  using PerVector = ForPartialVectors<TestMaskedReduceMin>;
  ForAllTypes(PerVector());
}

// Compares MaskedReduceMax on Iota(d, 2) with a scalar maximum over the same
// randomly enabled lanes.
struct TestMaskedReduceMax {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    RandomState rng;

    using TI = MakeSigned<T>;
    const Rebind<TI, D> di;
    const Vec<D> values = Iota(d, 2);  // lane i holds i + 2

    const size_t N = Lanes(d);
    // One 0/1 flag per lane; converted into a mask for each repetition.
    auto lane_flags = AllocateAligned<TI>(N);
    HWY_ASSERT(lane_flags);

    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
      T expected = 0;
      for (size_t i = 0; i < N; ++i) {
        const bool enabled = (Random32(&rng) & 1024) != 0;
        lane_flags[i] = enabled ? TI(1) : TI(0);
        if (enabled) {
          const T lane = ConvertScalarTo<T>(i + 2);
          if (expected < lane) {
            expected = lane;
          }
        }
      }

      const Mask<D> mask =
          RebindMask(d, Gt(Load(di, lane_flags.get()), Zero(di)));

      // If all elements are disabled the result is implementation defined,
      // so only verify when at least one lane is enabled.
      if (!AllFalse(d, mask)) {
        HWY_ASSERT_EQ(expected, MaskedReduceMax(d, mask, values));
      }
    }
  }
};

// Runs TestMaskedReduceMax for all lane types on partial vectors.
HWY_NOINLINE void TestAllMaskedReduceMax() {
  using PerVector = ForPartialVectors<TestMaskedReduceMax>;
  ForAllTypes(PerVector());
}

} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
Expand All @@ -367,6 +489,10 @@ HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf2);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf4);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf8);

HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceSum);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMin);
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMax);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down

0 comments on commit 326dba3

Please sign in to comment.