@@ -23,6 +23,37 @@ namespace at::native {
23
23
24
24
namespace {
25
25
26
+ template <typename T>
27
+ void LayerNormSecondPass (
28
+ const T* X_ptr,
29
+ const T* gamma_data,
30
+ const T* beta_data,
31
+ T* Y_ptr,
32
+ int64_t N,
33
+ T scale,
34
+ T bias) {
35
+ using Vec = vec::Vectorized<T>;
36
+ const bool gamma_null = gamma_data == nullptr ;
37
+ const bool beta_null = beta_data == nullptr ;
38
+ if (gamma_null || beta_null) {
39
+ for (const auto j : c10::irange (N)) {
40
+ const T gamma_v = gamma_null ? T (1 ) : gamma_data[j];
41
+ const T beta_v = beta_null ? T (0 ) : beta_data[j];
42
+ Y_ptr[j] = (X_ptr[j] + bias) * scale * gamma_v + beta_v;
43
+ }
44
+ } else {
45
+ vec::map3<T>(
46
+ [scale, bias](Vec x, Vec gamma, Vec beta) {
47
+ return (x + Vec (bias)) * Vec (scale) * gamma + beta;
48
+ },
49
+ Y_ptr,
50
+ X_ptr,
51
+ gamma_data,
52
+ beta_data,
53
+ N);
54
+ }
55
+ }
56
+
26
57
template <typename T,
27
58
typename std::enable_if_t <!is_reduced_floating_point_v<T>, int > = 0 >
28
59
void LayerNormKernelImplInternal (
@@ -35,16 +66,13 @@ void LayerNormKernelImplInternal(
35
66
Tensor* Y,
36
67
Tensor* mean,
37
68
Tensor* rstd) {
38
- using Vec = vec::Vectorized<T>;
39
69
const T* X_data = X.const_data_ptr <T>();
40
70
const T* gamma_data = gamma.defined () ? gamma.const_data_ptr <T>() : nullptr ;
41
71
const T* beta_data = beta.defined () ? beta.const_data_ptr <T>() : nullptr ;
42
72
T* Y_data = Y->data_ptr <T>();
43
73
T* mean_data = mean ? mean->data_ptr <T>() : nullptr ;
44
74
T* rstd_data = rstd ? rstd->data_ptr <T>() : nullptr ;
45
75
46
- const bool gamma_null = gamma_data == nullptr ;
47
- const bool beta_null = beta_data == nullptr ;
48
76
const bool mean_null = mean_data == nullptr ;
49
77
const bool rstd_null = rstd_data == nullptr ;
50
78
at::parallel_for (0 , M, 1 , [&](int64_t start, int64_t end) {
@@ -55,23 +83,7 @@ void LayerNormKernelImplInternal(
55
83
rstd_val = T (1 ) / std::sqrt (rstd_val + eps);
56
84
const T scale = rstd_val;
57
85
const T bias = - mean_val;
58
- if (gamma_null || beta_null) {
59
- for (const auto j : c10::irange (N)) {
60
- const T gamma_v = gamma_null ? T (1 ) : gamma_data[j];
61
- const T beta_v = beta_null ? T (0 ) : beta_data[j];
62
- Y_ptr[j] = (X_ptr[j] + bias) * rstd_val * gamma_v + beta_v;
63
- }
64
- } else {
65
- vec::map3<T>(
66
- [scale, bias](Vec x, Vec gamma, Vec beta) {
67
- return (x + Vec (bias)) * Vec (scale) * gamma + beta;
68
- },
69
- Y_ptr,
70
- X_ptr,
71
- gamma_data,
72
- beta_data,
73
- N);
74
- }
86
+ LayerNormSecondPass<T>(X_ptr, gamma_data, beta_data, Y_ptr, N, scale, bias);
75
87
if (!mean_null) {
76
88
mean_data[i] = mean_val;
77
89
}
@@ -191,6 +203,14 @@ void layer_norm_backward_frame(
191
203
T* dX_data,
192
204
T* dgamma_buffer_ptr,
193
205
T* dbeta_buffer_ptr,
206
+ // NOTE: the below @lint-ignore is only necessary because we compile
207
+ // specializations of this function for c10::complex.
208
+ // It's extremely likely that nobody actually takes layer norms of
209
+ // complex tensors, and even if they are, c10::complex is laid out poorly
210
+ // and basically should never be used.
211
+ // So it would be nice in the future to figure out how to stop compiling
212
+ // specializations of compute kernels for c10::complex.
213
+ // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
194
214
const opmath_t scale,
195
215
const bool gamma_null,
196
216
const bool dX_null,
@@ -481,7 +501,7 @@ void layer_norm_backward_frame(
481
501
fVec r_fvec0 = fVec (a) * dy_fvec0 * gamma_fvec0 + fVec (b) * x_fvec0 + fVec (c);
482
502
fVec r_fvec1 = fVec (a) * dy_fvec1 * gamma_fvec1 + fVec (b) * x_fvec1 + fVec (c);
483
503
bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
484
- r_bvec.store (dX_ptr + d, N - d);
504
+ r_bvec.store (dX_ptr + d, static_cast < int >( N - d) );
485
505
}
486
506
}
487
507
}
0 commit comments