
Commit 741539a

sharpobject authored and pytorchmergebot committed
Split out second pass of LayerNorm for profiler attribution reasons (pytorch#153578)
Summary: Split out second pass of LayerNorm so it's more likely to show up in profiler output. In my testing with perf, the samples from the lambda in the current implementation are attributed somewhat haphazardly.

Differential Revision: D74181627

Pull Request resolved: pytorch#153578
Approved by: https://github.com/hl475
1 parent: a9adc9a · commit: 741539a
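For context, a minimal sketch of the refactoring pattern behind this change (hypothetical names, not the ATen code): hoisting the body of a hot loop out of a lambda and into a named function gives perf a distinct symbol, or at least a distinct inline frame, to attribute samples to, instead of folding them into whichever caller the lambda was inlined into.

#include <cstdint>
#include <vector>

// Named helper: appears under its own symbol in profiler output
// (or as a distinct inline frame if the compiler inlines it).
void normalize_row(const float* x, float* y, int64_t n, float scale, float bias) {
  for (int64_t j = 0; j < n; ++j) {
    y[j] = (x[j] + bias) * scale;
  }
}

void normalize_rows(const std::vector<float>& x, std::vector<float>& y,
                    int64_t rows, int64_t cols) {
  for (int64_t i = 0; i < rows; ++i) {
    // Before the split, this body would live in a lambda passed to the
    // parallel loop, so samples tended to land on the enclosing kernel.
    normalize_row(x.data() + i * cols, y.data() + i * cols, cols,
                  /*scale=*/1.0f, /*bias=*/0.0f);
  }
}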

File tree: 2 files changed (+45, −21 lines)
aten/src/ATen/OpMathType.h
aten/src/ATen/native/cpu/layer_norm_kernel.cpp


aten/src/ATen/OpMathType.h

Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,10 @@ struct OpMathType<at::Float8_e4m3fnuz> {
   using type = float;
 };
 template <>
+struct OpMathType<at::Float8_e8m0fnu> {
+  using type = float;
+};
+template <>
 struct OpMathType<c10::complex<Half>> {
   using type = c10::complex<float>;
 };
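The OpMathType trait maps a storage dtype to the wider type used for intermediate arithmetic, and this hunk gives the e8m0 float8 format the same float mapping as the other Float8 variants. A hedged sketch of how such a trait is typically consumed, assuming the at::opmath_type alias that this header also defines (the function below is illustrative, not a call site from this diff):

#include <cstdint>

#include <ATen/OpMathType.h>

// Accumulate reduced-precision values in the wider "op math" type selected by
// OpMathType (float for the Float8 types, including Float8_e8m0fnu after this
// change).
template <typename scalar_t>
at::opmath_type<scalar_t> sum_in_opmath(const scalar_t* data, int64_t n) {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t acc = opmath_t(0);
  for (int64_t i = 0; i < n; ++i) {
    acc += static_cast<opmath_t>(data[i]);
  }
  return acc;
}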

aten/src/ATen/native/cpu/layer_norm_kernel.cpp

Lines changed: 41 additions & 21 deletions
@@ -23,6 +23,37 @@ namespace at::native {
 
 namespace {
 
+template <typename T>
+void LayerNormSecondPass(
+    const T* X_ptr,
+    const T* gamma_data,
+    const T* beta_data,
+    T* Y_ptr,
+    int64_t N,
+    T scale,
+    T bias) {
+  using Vec = vec::Vectorized<T>;
+  const bool gamma_null = gamma_data == nullptr;
+  const bool beta_null = beta_data == nullptr;
+  if (gamma_null || beta_null) {
+    for (const auto j : c10::irange(N)) {
+      const T gamma_v = gamma_null ? T(1) : gamma_data[j];
+      const T beta_v = beta_null ? T(0) : beta_data[j];
+      Y_ptr[j] = (X_ptr[j] + bias) * scale * gamma_v + beta_v;
+    }
+  } else {
+    vec::map3<T>(
+        [scale, bias](Vec x, Vec gamma, Vec beta) {
+          return (x + Vec(bias)) * Vec(scale) * gamma + beta;
+        },
+        Y_ptr,
+        X_ptr,
+        gamma_data,
+        beta_data,
+        N);
+  }
+}
+
 template <typename T,
           typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
 void LayerNormKernelImplInternal(
@@ -35,16 +66,13 @@ void LayerNormKernelImplInternal(
     Tensor* Y,
     Tensor* mean,
     Tensor* rstd) {
-  using Vec = vec::Vectorized<T>;
   const T* X_data = X.const_data_ptr<T>();
   const T* gamma_data = gamma.defined() ? gamma.const_data_ptr<T>() : nullptr;
   const T* beta_data = beta.defined() ? beta.const_data_ptr<T>() : nullptr;
   T* Y_data = Y->data_ptr<T>();
   T* mean_data = mean ? mean->data_ptr<T>() : nullptr;
   T* rstd_data = rstd ? rstd->data_ptr<T>() : nullptr;
 
-  const bool gamma_null = gamma_data == nullptr;
-  const bool beta_null = beta_data == nullptr;
   const bool mean_null = mean_data == nullptr;
   const bool rstd_null = rstd_data == nullptr;
   at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
@@ -55,23 +83,7 @@ void LayerNormKernelImplInternal(
       rstd_val = T(1) / std::sqrt(rstd_val + eps);
       const T scale = rstd_val;
       const T bias = - mean_val;
-      if (gamma_null || beta_null) {
-        for (const auto j : c10::irange(N)) {
-          const T gamma_v = gamma_null ? T(1) : gamma_data[j];
-          const T beta_v = beta_null ? T(0) : beta_data[j];
-          Y_ptr[j] = (X_ptr[j] + bias) * rstd_val * gamma_v + beta_v;
-        }
-      } else {
-        vec::map3<T>(
-            [scale, bias](Vec x, Vec gamma, Vec beta) {
-              return (x + Vec(bias)) * Vec(scale) * gamma + beta;
-            },
-            Y_ptr,
-            X_ptr,
-            gamma_data,
-            beta_data,
-            N);
-      }
+      LayerNormSecondPass<T>(X_ptr, gamma_data, beta_data, Y_ptr, N, scale, bias);
       if (!mean_null) {
         mean_data[i] = mean_val;
       }
@@ -191,6 +203,14 @@ void layer_norm_backward_frame(
     T* dX_data,
     T* dgamma_buffer_ptr,
     T* dbeta_buffer_ptr,
+    // NOTE: the below @lint-ignore is only necessary because we compile
+    // specializations of this function for c10::complex.
+    // It's extremely likely that nobody actually takes layer norms of
+    // complex tensors, and even if they are, c10::complex is laid out poorly
+    // and basically should never be used.
+    // So it would be nice in the future to figure out how to stop compiling
+    // specializations of compute kernels for c10::complex.
+    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
     const opmath_t scale,
     const bool gamma_null,
     const bool dX_null,
@@ -481,7 +501,7 @@ void layer_norm_backward_frame(
       fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
       fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
       bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
-      r_bvec.store(dX_ptr + d, N - d);
+      r_bvec.store(dX_ptr + d, static_cast<int>(N - d));
     }
   }
 }
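For reference, a scalar sketch of what the split-out second pass computes (plain C++, no ATen vectorization; the name below is illustrative): per row, y[j] = (x[j] - mean) * rstd * gamma[j] + beta[j], which the kernel rewrites as (x[j] + bias) * scale with scale = rstd and bias = -mean before handing the row to LayerNormSecondPass.

#include <cstdint>

// Reference version of the per-row second pass; a null gamma/beta falls back
// to 1/0, matching the scalar branch of LayerNormSecondPass above.
template <typename T>
void layer_norm_second_pass_ref(const T* x, const T* gamma, const T* beta,
                                T* y, int64_t n, T scale, T bias) {
  for (int64_t j = 0; j < n; ++j) {
    const T g = gamma ? gamma[j] : T(1);
    const T b = beta ? beta[j] : T(0);
    y[j] = (x[j] + bias) * scale * g + b;
  }
}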
