Skip to content

Commit

Permalink
Merge pull request #2423 from cambridgeconsultants:cc_up_complex_arit…
Browse files Browse the repository at this point in the history
…hmetic

PiperOrigin-RevId: 723985611
  • Loading branch information
copybara-github committed Feb 6, 2025
2 parents 87e25b0 + 86e5d17 commit 8f7b5d4
Show file tree
Hide file tree
Showing 6 changed files with 580 additions and 0 deletions.
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ HWY_TESTS = [
("hwy/tests/", "combine_test"),
("hwy/tests/", "compare_test"),
("hwy/tests/", "compress_test"),
("hwy/tests/", "complex_arithmetic_test"),
("hwy/tests/", "concat_test"),
("hwy/tests/", "convert_test"),
("hwy/tests/", "count_test"),
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@ set(HWY_TEST_FILES
hwy/tests/cast_test.cc
hwy/tests/combine_test.cc
hwy/tests/compare_test.cc
hwy/tests/complex_arithmetic_test.cc
hwy/tests/compress_test.cc
hwy/tests/concat_test.cc
hwy/tests/convert_test.cc
Expand Down
35 changes: 35 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,41 @@ to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc.
<code>V **MaskedApproximateReciprocal**(M m, V a)</code>: returns the
result of ApproximateReciprocal where m is true and zero otherwise.

#### Complex number operations

Complex numbers are represented as pairs of adjacent lanes holding the real and
imaginary components, with the real components in even-indexed lanes and the
imaginary components in odd-indexed lanes.

All multiplies in this section perform complex multiplication,
i.e. `(a + ib)(c + id)`.

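In the formulas below, `j` is an even lane index, so lanes `j` and `j + 1` hold
the real and imaginary parts of one complex value. `i` denotes the imaginary
unit, except in `mask[i]` and `no[i]`, where it is a lane index.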

* `V`: `{f}` \
<code>V **ComplexConj**(V v)</code>: returns the complex conjugate of the vector,
i.e. negates the imaginary (odd) lanes. This is equivalent to
`OddEven(Neg(v), v)`.
* `V`: `{f}` \
<code>V **MulComplex**(V a, V b)</code>: returns `(a[j] + i.a[j + 1])(b[j] + i.b[j + 1])`
* `V`: `{f}` \
<code>V **MulComplexConj**(V a, V b)</code>: returns `(a[j] + i.a[j + 1])(b[j] - i.b[j + 1])`
* `V`: `{f}` \
<code>V **MulComplexAdd**(V a, V b, V c)</code>: returns
`(a[j] + i.a[j + 1])(b[j] + i.b[j + 1]) + (c[j] + i.c[j + 1])`
* `V`: `{f}` \
<code>V **MulComplexConjAdd**(V a, V b, V c)</code>: returns
`(a[j] + i.a[j + 1])(b[j] - i.b[j + 1]) + (c[j] + i.c[j + 1])`
* `V`: `{f}` \
<code>V **MaskedMulComplexConjAdd**(M mask, V a, V b, V c)</code>: returns
`(a[j] + i.a[j + 1])(b[j] - i.b[j + 1]) + (c[j] + i.c[j + 1])` or `0` if
`mask[i]` is false.
* `V`: `{f}` \
<code>V **MaskedMulComplexConj**(M mask, V a, V b)</code>: returns
`(a[j] + i.a[j + 1])(b[j] - i.b[j + 1])` or `0` if `mask[i]` is false.
* `V`: `{f}` \
<code>V **MaskedMulComplexOr**(V no, M mask, V a, V b)</code>: returns
`(a[j] + i.a[j + 1])(b[j] + i.b[j + 1])` or `no[i]` if `mask[i]` is false.

#### Shifts

**Note**: Counts not in `[0, sizeof(T)*8)` yield implementation-defined results.
Expand Down
124 changes: 124 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6452,6 +6452,130 @@ HWY_API VFromD<DU64> SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a,
return svdot_u64(sum, a, b);
}

// ------------------------------ MulComplex* / MaskedMulComplex*

// Per-target flag to prevent generic_ops-inl.h from defining MulComplex*.
// The macro is toggled (defined if absent, undefined if present) rather than
// unconditionally defined: generic_ops-inl.h only emits its fallbacks when
// defined(HWY_NATIVE_CPLX) == defined(HWY_TARGET_TOGGLE), so flipping the
// state here makes that comparison fail for this target (assuming the two
// were in sync before this point).
#ifdef HWY_NATIVE_CPLX
#undef HWY_NATIVE_CPLX
#else
#define HWY_NATIVE_CPLX
#endif

// Returns the complex conjugate of every (real, imaginary) lane pair: the
// even (real) lanes pass through unchanged, the odd (imaginary) lanes are
// negated.
template <class V, HWY_IF_NOT_UNSIGNED(TFromV<V>)>
HWY_API V ComplexConj(V a) {
  const V negated = Neg(a);
  // Odd lanes come from `negated`, even lanes from the original input.
  return OddEven(negated, a);
}

namespace detail {
// HWY_SVE_CPLX_FMA_ROT generates, for one vector type and one rotation ROT,
// two thin wrappers around the SVE intrinsic sv<OP>_<CHAR><BITS>:
//   NAME<ROT>(a, b, c)     calls the `_x` form with HWY_SVE_PTRUE(BITS) as
//                          the governing predicate.
//   NAME##Z<ROT>(m, a, b, c) calls the `_z` form with the caller's mask `m`
//                          (per ACLE naming, `_z` zeroes inactive lanes).
// Arguments are forwarded in order, so `a` lands in the intrinsic's first
// vector operand (the accumulator/addend in the ACLE svcmla signature) and
// `b`, `c` are the two multiplicands; ROT is passed as the rotation immediate.
// HALF is not referenced by the macro body; presumably it is only present to
// match the argument list supplied by HWY_SVE_FOREACH_F.
#define HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, ROT) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME##ROT(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b, c, ROT); \
} \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME##Z##ROT(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
HWY_SVE_V(BASE, BITS) b, HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_z(m, a, b, c, ROT); \
}

// Expands HWY_SVE_CPLX_FMA_ROT once for each of the four rotations
// (0, 90, 180 and 270 degrees).
#define HWY_SVE_CPLX_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 0) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 90) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 180) \
HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 270)

// Only the floating-point wrappers are instantiated here, yielding
// ComplexMulAdd{0,90,180,270} and ComplexMulAddZ{0,90,180,270}.
// Only SVE2 has complex multiply-add for integer types, and those intrinsics
// do not include masked variants.
HWY_SVE_FOREACH_F(HWY_SVE_CPLX_FMA, ComplexMulAdd, cmla)
#undef HWY_SVE_CPLX_FMA
#undef HWY_SVE_CPLX_FMA_ROT
} // namespace detail

// Floating-point a * conj(b) + c per complex lane pair, zero where `mask` is
// false.
template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) {
  // Two chained zeroing complex multiply-accumulates. The rotation-0 step
  // starts from the addend `c`; the rotation-270 step accumulates onto that
  // partial result. Note the multiplicands are passed in (b, a) order.
  const V partial = detail::ComplexMulAddZ0(mask, c, b, a);
  return detail::ComplexMulAddZ270(mask, partial, b, a);
}

// Floating-point a * conj(b) per complex lane pair, zero where `mask` is
// false.
template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulComplexConj(M mask, V a, V b) {
  // Reuse the multiply-add form with an all-zero addend.
  const DFromV<V> d;
  return MaskedMulComplexConjAdd(mask, a, b, Zero(d));
}

// Floating-point complex multiply-add: a * b + c per complex lane pair.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplexAdd(V a, V b, V c) {
  // Two chained complex multiply-accumulates. The rotation-0 step starts from
  // the addend `c`; the rotation-90 step accumulates onto that partial result.
  const V partial = detail::ComplexMulAdd0(c, a, b);
  return detail::ComplexMulAdd90(partial, a, b);
}

// Floating-point complex multiply: a * b per complex lane pair.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplex(V a, V b) {
  // Reuse the multiply-add form with an all-zero addend.
  const DFromV<V> d;
  return MulComplexAdd(a, b, Zero(d));
}

// Floating-point complex multiply a * b where `mask` is true; lanes where it
// is false take the corresponding lane of `no`.
template <class V, class M, HWY_IF_FLOAT_V(V)>
HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) {
  const V product = MulComplex(a, b);
  return IfThenElse(mask, product, no);
}

// Floating-point a * conj(b) + c per complex lane pair.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplexConjAdd(V a, V b, V c) {
  // Two chained complex multiply-accumulates. The rotation-0 step starts from
  // the addend `c`; the rotation-270 step accumulates onto that partial
  // result. Note the multiplicands are passed in (b, a) order.
  const V partial = detail::ComplexMulAdd0(c, b, a);
  return detail::ComplexMulAdd270(partial, b, a);
}

// Floating-point a * conj(b) per complex lane pair.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulComplexConj(V a, V b) {
  // Reuse the multiply-add form with an all-zero addend.
  const DFromV<V> d;
  return MulComplexConjAdd(a, b, Zero(d));
}

// TODO SVE2 does have intrinsics for integers but not masked variants
// Non-float complex multiply a * b per complex lane pair, built from
// ordinary lane-wise arithmetic.
template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplex(V a, V b) {
  // Broadcast each operand's real (even-lane) and imaginary (odd-lane) parts
  // to both lanes of every pair: a = a_re + i*a_im, b = b_re + i*b_im.
  const auto a_re = DupEven(a);
  const auto a_im = DupOdd(a);
  const auto b_re = DupEven(b);
  const auto b_im = DupOdd(b);

  // real = a_re*b_re - a_im*b_im, imag = a_re*b_im + a_im*b_re.
  const auto real = Sub(Mul(a_re, b_re), Mul(a_im, b_im));
  const auto imag = MulAdd(a_re, b_im, Mul(a_im, b_re));
  return OddEven(imag, real);
}

// Non-float a * conj(b) per complex lane pair, built from ordinary lane-wise
// arithmetic.
template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplexConj(V a, V b) {
  // Broadcast each operand's real (even-lane) and imaginary (odd-lane) parts
  // to both lanes of every pair: a = a_re + i*a_im, b = b_re + i*b_im.
  const auto a_re = DupEven(a);
  const auto a_im = DupOdd(a);
  const auto b_re = DupEven(b);
  const auto b_im = DupOdd(b);

  // real = a_re*b_re + a_im*b_im, imag = a_im*b_re - a_re*b_im.
  const auto real = MulAdd(a_re, b_re, Mul(a_im, b_im));
  const auto imag = Sub(Mul(a_im, b_re), Mul(a_re, b_im));
  return OddEven(imag, real);
}

// Non-float complex multiply-add: a * b + c per complex lane pair.
template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplexAdd(V a, V b, V c) {
  const V product = MulComplex(a, b);
  return Add(product, c);
}

// Non-float a * conj(b) + c per complex lane pair.
template <class V, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MulComplexConjAdd(V a, V b, V c) {
  const V product = MulComplexConj(a, b);
  return Add(product, c);
}

// Non-float a * conj(b) + c per complex lane pair, zero where `mask` is
// false.
template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) {
  // Compute for all lanes, then zero those not selected by the mask.
  const V unmasked = MulComplexConjAdd(a, b, c);
  return IfThenElseZero(mask, unmasked);
}

// Non-float a * conj(b) per complex lane pair, zero where `mask` is false.
template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulComplexConj(M mask, V a, V b) {
  // Compute for all lanes, then zero those not selected by the mask.
  const V unmasked = MulComplexConj(a, b);
  return IfThenElseZero(mask, unmasked);
}

// Non-float complex multiply a * b where `mask` is true; lanes where it is
// false take the corresponding lane of `no`.
template <class V, class M, HWY_IF_NOT_FLOAT_V(V)>
HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) {
  const V product = MulComplex(a, b);
  return IfThenElse(mask, product, no);
}

// ------------------------------ AESRound / CLMul

// Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a
Expand Down
65 changes: 65 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4626,6 +4626,71 @@ HWY_API V MulSub(V mul, V x, V sub) {
return Sub(Mul(mul, x), sub);
}
#endif // HWY_NATIVE_INT_FMA
// ------------------------------ MulComplex* / MaskedMulComplex*

#if (defined(HWY_NATIVE_CPLX) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_CPLX
#undef HWY_NATIVE_CPLX
#else
#define HWY_NATIVE_CPLX
#endif

#if HWY_TARGET != HWY_SCALAR || HWY_IDE

// Returns the complex conjugate of every (real, imaginary) lane pair: the
// even (real) lanes pass through unchanged, the odd (imaginary) lanes are
// negated.
template <class V, HWY_IF_NOT_UNSIGNED(TFromV<V>)>
HWY_API V ComplexConj(V a) {
  const V negated = Neg(a);
  // Odd lanes come from `negated`, even lanes from the original input.
  return OddEven(negated, a);
}

// Generic complex multiply a * b per complex lane pair (real parts in even
// lanes, imaginary parts in odd lanes).
template <class V>
HWY_API V MulComplex(V a, V b) {
  // Broadcast each operand's real and imaginary parts to both lanes of every
  // pair: a = a_re + i*a_im, b = b_re + i*b_im.
  const auto a_re = DupEven(a);
  const auto a_im = DupOdd(a);
  const auto b_re = DupEven(b);
  const auto b_im = DupOdd(b);

  // real = a_re*b_re - a_im*b_im, imag = a_re*b_im + a_im*b_re.
  const auto real = Sub(Mul(a_re, b_re), Mul(a_im, b_im));
  const auto imag = MulAdd(a_re, b_im, Mul(a_im, b_re));
  return OddEven(imag, real);
}

// Generic a * conj(b) per complex lane pair (real parts in even lanes,
// imaginary parts in odd lanes).
template <class V>
HWY_API V MulComplexConj(V a, V b) {
  // Broadcast each operand's real and imaginary parts to both lanes of every
  // pair: a = a_re + i*a_im, b = b_re + i*b_im.
  const auto a_re = DupEven(a);
  const auto a_im = DupOdd(a);
  const auto b_re = DupEven(b);
  const auto b_im = DupOdd(b);

  // real = a_re*b_re + a_im*b_im, imag = a_im*b_re - a_re*b_im.
  const auto real = MulAdd(a_re, b_re, Mul(a_im, b_im));
  const auto imag = Sub(Mul(a_im, b_re), Mul(a_re, b_im));
  return OddEven(imag, real);
}

// Generic complex multiply-add: a * b + c per complex lane pair.
template <class V>
HWY_API V MulComplexAdd(V a, V b, V c) {
  const V product = MulComplex(a, b);
  return Add(product, c);
}

// Generic a * conj(b) + c per complex lane pair.
template <class V>
HWY_API V MulComplexConjAdd(V a, V b, V c) {
  const V product = MulComplexConj(a, b);
  return Add(product, c);
}

// Generic a * conj(b) + c per complex lane pair, zero where `mask` is false.
template <class V, class M>
HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) {
  // Compute for all lanes, then zero those not selected by the mask.
  const V unmasked = MulComplexConjAdd(a, b, c);
  return IfThenElseZero(mask, unmasked);
}

// Generic a * conj(b) per complex lane pair, zero where `mask` is false.
template <class V, class M>
HWY_API V MaskedMulComplexConj(M mask, V a, V b) {
  // Compute for all lanes, then zero those not selected by the mask.
  const V unmasked = MulComplexConj(a, b);
  return IfThenElseZero(mask, unmasked);
}

// Generic complex multiply a * b where `mask` is true; lanes where it is
// false take the corresponding lane of `no`.
template <class V, class M>
HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) {
  const V product = MulComplex(a, b);
  return IfThenElse(mask, product, no);
}
#endif // HWY_TARGET != HWY_SCALAR

#endif // HWY_NATIVE_CPLX

// ------------------------------ MaskedMulAddOr
#if (defined(HWY_NATIVE_MASKED_INT_FMA) == defined(HWY_TARGET_TOGGLE))
Expand Down
Loading

0 comments on commit 8f7b5d4

Please sign in to comment.