Skip to content

Commit 4a0d23f

Browse files
jan-wassenbergdan-zheng
authored andcommitted
Add missing hn::, fixes google#25
PiperOrigin-RevId: 609914890
1 parent af715d2 commit 4a0d23f

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

ops.h

+15-14
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,16 @@ HWY_INLINE void MatVec(const CompressedArray<MatT, kCapacity>& mat,
198198

199199
template <class D, HWY_IF_F32_D(D)>
200200
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
201-
const hn::Vec<D> kMul = Set(d, 0.044715f);
201+
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
202202
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
203-
const hn::Vec<D> kHalf = Set(d, 0.5f);
203+
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
204204

205205
// tanh approximation matches training.
206206
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
207207
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
208208
// 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
209209
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
210-
return Mul(v, cdf);
210+
return hn::Mul(v, cdf);
211211
}
212212

213213
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
@@ -230,21 +230,22 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
230230
size_t i = 0;
231231
if (size >= 2 * NF) {
232232
for (; i < size - 2 * NF; i += 2 * NF) {
233-
const VF mul0 = LoadU(df, mul + i);
234-
const VF mul1 = LoadU(df, mul + i + NF);
235-
const VF g0 = Mul(mul0, Gelu(df, LoadU(df, gelu_in + i)));
236-
const VF g1 = Mul(mul1, Gelu(df, LoadU(df, gelu_in + i + NF)));
233+
const VF mul0 = hn::LoadU(df, mul + i);
234+
const VF mul1 = hn::LoadU(df, mul + i + NF);
235+
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
236+
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
237237
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
238-
StoreU(bf, dbf, out + i);
238+
hn::StoreU(bf, dbf, out + i);
239239
}
240240
}
241241
if (i != size) {
242242
const size_t remaining = size - i;
243-
const VF mul0 = LoadN(df, mul + i, remaining);
244-
const VF g0 = Mul(mul0, Gelu(df, LoadN(df, gelu_in + i, remaining)));
243+
const VF mul0 = hn::LoadN(df, mul + i, remaining);
244+
const VF g0 =
245+
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
245246
const hn::Half<decltype(dbf)> dbfh;
246247
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
247-
StoreN(bfh, dbfh, out + i, remaining);
248+
hn::StoreN(bfh, dbfh, out + i, remaining);
248249
}
249250
}
250251

@@ -381,7 +382,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
381382

382383
constexpr float eps = 1e-6f;
383384
const float ss = SquaredL2(inout, size);
384-
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
385+
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
385386

386387
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
387388
for (size_t i = 0; i < size; i += 2 * N32) {
@@ -409,7 +410,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
409410

410411
constexpr float eps = 1e-6f;
411412
const float ss = SquaredL2(x, size);
412-
const VF vss = Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
413+
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
413414

414415
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
415416
for (size_t i = 0; i < size; i += 2 * N32) {
@@ -436,7 +437,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
436437

437438
constexpr float eps = 1e-6f;
438439
const float ss = SquaredL2(x, size);
439-
const VF vss = Set(df32, 1.0f / sqrtf(ss / size + eps));
440+
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
440441

441442
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
442443
for (size_t i = 0; i < size; i += 2 * N32) {

0 commit comments

Comments
 (0)