@@ -198,16 +198,16 @@ HWY_INLINE void MatVec(const CompressedArray<MatT, kCapacity>& mat,
198
198
199
199
// Applies the tanh approximation of GELU to every lane of `v`:
//   v * (0.5 + 0.5 * tanh(kSqrt2OverPi * (v + kMul * v^3))).
// `d` is the vector tag used to broadcast the constants and to call Tanh;
// HWY_IF_F32_D presumably restricts D to f32 tags - confirm against Highway.
// Returns a vector of the same type as `v`.
template <class D , HWY_IF_F32_D(D)>
200
200
static HWY_INLINE hn::Vec<D> Gelu (D d, hn::Vec<D> v) {
201
- const hn::Vec<D> kMul = Set (d, 0 .044715f );
201
+ const hn::Vec<D> kMul = hn:: Set (d, 0 .044715f );
202
202
const hn::Vec<D> kSqrt2OverPi = hn::Set (d, 0 .797884560804236f );
203
- const hn::Vec<D> kHalf = Set (d, 0 .5f );
203
+ const hn::Vec<D> kHalf = hn:: Set (d, 0 .5f );
204
204
205
205
// tanh approximation matches training.
206
206
const hn::Vec<D> v3 = hn::Mul (hn::Mul (v, v), v);
207
207
const hn::Vec<D> arg = hn::Mul (kSqrt2OverPi , hn::MulAdd (kMul , v3, v));
208
208
// 0.5 * (1 + tanh) = MulAdd(0.5, tanh, 0.5).
209
209
const hn::Vec<D> cdf = hn::MulAdd (kHalf , hn::Tanh (d, arg), kHalf );
210
- return Mul (v, cdf);
210
+ return hn:: Mul (v, cdf);
211
211
}
212
212
213
213
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu (float * HWY_RESTRICT x,
@@ -230,21 +230,22 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
230
230
size_t i = 0 ;
231
231
if (size >= 2 * NF) {
232
232
for (; i < size - 2 * NF; i += 2 * NF) {
233
- const VF mul0 = LoadU (df, mul + i);
234
- const VF mul1 = LoadU (df, mul + i + NF);
235
- const VF g0 = Mul (mul0, Gelu (df, LoadU (df, gelu_in + i)));
236
- const VF g1 = Mul (mul1, Gelu (df, LoadU (df, gelu_in + i + NF)));
233
+ const VF mul0 = hn:: LoadU (df, mul + i);
234
+ const VF mul1 = hn:: LoadU (df, mul + i + NF);
235
+ const VF g0 = hn:: Mul (mul0, Gelu (df, hn:: LoadU (df, gelu_in + i)));
236
+ const VF g1 = hn:: Mul (mul1, Gelu (df, hn:: LoadU (df, gelu_in + i + NF)));
237
237
const hn::Vec<decltype (dbf)> bf = hn::OrderedDemote2To (dbf, g0, g1);
238
- StoreU (bf, dbf, out + i);
238
+ hn:: StoreU (bf, dbf, out + i);
239
239
}
240
240
}
241
241
if (i != size) {
242
242
const size_t remaining = size - i;
243
- const VF mul0 = LoadN (df, mul + i, remaining);
244
- const VF g0 = Mul (mul0, Gelu (df, LoadN (df, gelu_in + i, remaining)));
243
+ const VF mul0 = hn::LoadN (df, mul + i, remaining);
244
+ const VF g0 =
245
+ hn::Mul (mul0, Gelu (df, hn::LoadN (df, gelu_in + i, remaining)));
245
246
const hn::Half<decltype (dbf)> dbfh;
246
247
const hn::Vec<decltype (dbfh)> bfh = hn::DemoteTo (dbfh, g0);
247
- StoreN (bfh, dbfh, out + i, remaining);
248
+ hn:: StoreN (bfh, dbfh, out + i, remaining);
248
249
}
249
250
}
250
251
@@ -381,7 +382,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
381
382
382
383
constexpr float eps = 1e-6f ;
383
384
const float ss = SquaredL2 (inout, size);
384
- const VF vss = Set (df32, 1 .0f / sqrtf (ss / static_cast <int >(size) + eps));
385
+ const VF vss = hn:: Set (df32, 1 .0f / sqrtf (ss / static_cast <int >(size) + eps));
385
386
386
387
HWY_DASSERT (size % (2 * MaxLanes (df32)) == 0 );
387
388
for (size_t i = 0 ; i < size; i += 2 * N32) {
@@ -409,7 +410,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
409
410
410
411
constexpr float eps = 1e-6f ;
411
412
const float ss = SquaredL2 (x, size);
412
- const VF vss = Set (df32, 1 .0f / sqrtf (ss / static_cast <int >(size) + eps));
413
+ const VF vss = hn:: Set (df32, 1 .0f / sqrtf (ss / static_cast <int >(size) + eps));
413
414
414
415
HWY_DASSERT (size % (2 * MaxLanes (df32)) == 0 );
415
416
for (size_t i = 0 ; i < size; i += 2 * N32) {
@@ -436,7 +437,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
436
437
437
438
constexpr float eps = 1e-6f ;
438
439
const float ss = SquaredL2 (x, size);
439
- const VF vss = Set (df32, 1 .0f / sqrtf (ss / size + eps));
440
+ const VF vss = hn:: Set (df32, 1 .0f / sqrtf (ss / size + eps));
440
441
441
442
HWY_DASSERT (size % (2 * MaxLanes (df32)) == 0 );
442
443
for (size_t i = 0 ; i < size; i += 2 * N32) {
0 commit comments