
Commit 02093da

henrytwo authored and pytorchmergebot committed
Autogen native_batch_norm and native_batch_norm_backward (pytorch#79637)
This PR makes the `native_batch_norm` and `native_batch_norm_backward` ops autogen and implements their respective shape inference functions. Previously, these two ops were manually implemented.

cc: @ke1337 @antoniojkim @wconstab @desertfire

Pull Request resolved: pytorch#79637
Approved by: https://github.com/Gamrix, https://github.com/desertfire
1 parent fdd3e20 commit 02093da
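
For context, the shape contract the new `compute_shape_*` functions must encode is the one the eager `at::native_batch_norm` kernel already follows: it returns an `(output, save_mean, save_invstd)` triple where `output` matches the input's shape and the two saved statistics are per-channel vectors of length C. A minimal eager-mode sketch of that contract (illustrative only, not part of this commit; the `main` harness is hypothetical and assumes a C++ program linked against ATen):

```cpp
#include <ATen/ATen.h>

#include <tuple>

int main() {
  // An (N, C, H, W) input; batch norm keeps one mean/var entry per channel C.
  at::Tensor input = at::randn({2, 3, 4, 4});
  at::Tensor running_mean = at::zeros({3});
  at::Tensor running_var = at::ones({3});

  // native_batch_norm returns (output, save_mean, save_invstd).
  auto result = at::native_batch_norm(
      input, /*weight=*/{}, /*bias=*/{}, running_mean, running_var,
      /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);

  // output keeps the input's shape; the saved stats have shape (C,) = (3,).
  TORCH_CHECK(std::get<0>(result).sizes() == input.sizes());
  TORCH_CHECK(std::get<1>(result).numel() == 3);
  TORCH_CHECK(std::get<2>(result).numel() == 3);
  return 0;
}
```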

File tree

12 files changed (+78, -516 lines)


aten/src/ATen/native/ts_native_functions.yaml

Lines changed: 2 additions & 2 deletions
```diff
@@ -85,6 +85,8 @@ full_codegen:
 - mm
 - mul.Tensor
 - mv
+- native_batch_norm
+- native_batch_norm_backward
 - native_dropout
 - native_dropout_backward
 - native_layer_norm
@@ -153,8 +155,6 @@ supported:
 - expand
 - fill_.Scalar
 - narrow
-- native_batch_norm
-- native_batch_norm_backward
 - normal_
 - max_pool3d_with_indices
 - max_pool3d_with_indices_backward
```

build_variables.bzl

Lines changed: 0 additions & 1 deletion
```diff
@@ -420,7 +420,6 @@ lazy_tensor_core_sources = [
 lazy_tensor_ts_sources = [
     "torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
     "torch/csrc/lazy/ts_backend/config.cpp",
-    "torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
     "torch/csrc/lazy/ts_backend/ops/device_data.cpp",
     "torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
     "torch/csrc/lazy/ts_backend/ops/generic.cpp",
```

test/lazy/test_reuse_ir.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ def testBatchNorm(self):
         torch._lazy.mark_step()
 
         torch.testing.assert_close(z.cpu(), z_lazy.cpu())
-        assert metrics.counter_value("IrNodeReused_torch::lazy::TSNativeBatchNormForward") >= 7
+        assert metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm") >= 7
         metrics.reset()
         torch._lazy.ir_cache.reset()
 
```

torch/csrc/lazy/core/ir.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -140,9 +140,11 @@ Shape Node::computeShape(const std::function<Shape()>& shape_fn) {
 const std::vector<Output>& Node::operands() const {
   return operands_as_outputs_;
 }
+
 const Output& Node::operand(size_t i) const {
   return operands_as_outputs_.at(i);
 }
+
 const Output& Node::nullable_operand(size_t i) const {
   // We use kNullOutput instead of kNullValue here to avoid implicit casting,
   // which would prevent this method from returning a reference.
```

torch/csrc/lazy/core/shape_inference.cpp

Lines changed: 71 additions & 0 deletions
```diff
@@ -470,6 +470,77 @@ std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) {
   return {Shape(tensors[0].scalar_type(), out_shape)};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
+    const at::Tensor& input,
+    const c10::optional<at::Tensor>& weight,
+    const c10::optional<at::Tensor>& bias,
+    const c10::optional<at::Tensor>& running_mean,
+    const c10::optional<at::Tensor>& running_var,
+    bool training,
+    double momentum,
+    double eps) {
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  // A separate mean and var needs to be kept for each channel.
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  int64_t num_features = input.size(1);
+
+  if (running_mean.has_value()) {
+    shapes.emplace_back(
+        running_mean.value().scalar_type(), running_mean.value().sizes().vec());
+  } else {
+    shapes.emplace_back(
+        at::get_default_dtype_as_scalartype(),
+        std::vector<int64_t>{num_features});
+  }
+
+  if (running_var.has_value()) {
+    shapes.emplace_back(
+        running_var.value().scalar_type(), running_var.value().sizes().vec());
+  } else {
+    shapes.emplace_back(
+        at::get_default_dtype_as_scalartype(),
+        std::vector<int64_t>{num_features});
+  }
+  return shapes;
+}
+
+std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const c10::optional<at::Tensor>& weight,
+    const c10::optional<at::Tensor>& running_mean,
+    const c10::optional<at::Tensor>& running_var,
+    const c10::optional<at::Tensor>& save_mean,
+    const c10::optional<at::Tensor>& save_invstd,
+    bool train,
+    double eps,
+    ::std::array<bool, 3> output_mask) {
+  std::vector<torch::lazy::Shape> shapes;
+  shapes.reserve(3);
+  shapes.emplace_back(input.scalar_type(), input.sizes().vec());
+
+  // A separate mean and var needs to be kept for each channel.
+  TORCH_CHECK(
+      input.sizes().size() >= 2,
+      "Input tensor must have at least batch and channel dimensions!");
+  int64_t num_features = input.size(1);
+
+  // `weight` and `bias` are vectors of length C (number of channels).
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+  shapes.emplace_back(
+      at::get_default_dtype_as_scalartype(),
+      std::vector<int64_t>{num_features});
+
+  return shapes;
+}
+
 std::vector<Shape> compute_shape_native_layer_norm(
     const at::Tensor& input,
     at::IntArrayRef normalized_shape,
```
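
To sanity-check the forward shape function above, a small hypothetical harness (not part of the commit, and assuming the `torch::lazy::Shape::sizes()` accessor) can compare its output against the expected `(input shape, (C,), (C,))` triple. When no running stats are passed, the per-channel shapes fall back to `{num_features}` in the default dtype, per the `else` branches above:

```cpp
#include <ATen/ATen.h>
#include <torch/csrc/lazy/core/shape_inference.h>

int main() {
  at::Tensor input = at::randn({2, 3, 4, 4});  // (N, C, H, W) with C = 3

  std::vector<torch::lazy::Shape> shapes =
      torch::lazy::compute_shape_native_batch_norm(
          input, /*weight=*/c10::nullopt, /*bias=*/c10::nullopt,
          /*running_mean=*/c10::nullopt, /*running_var=*/c10::nullopt,
          /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);

  // Expected: output shape (2, 3, 4, 4), then (3,) for mean and invstd.
  TORCH_CHECK(shapes.size() == 3);
  TORCH_CHECK(shapes[0].sizes() == input.sizes());
  TORCH_CHECK(shapes[1].sizes().size() == 1 && shapes[1].sizes()[0] == 3);
  TORCH_CHECK(shapes[2].sizes().size() == 1 && shapes[2].sizes()[0] == 3);
  return 0;
}
```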

torch/csrc/lazy/core/shape_inference.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -50,6 +50,8 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_max(const at::Tensor & s
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mean(const at::Tensor & self, c10::optional<at::ScalarType> dtype);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_min(const at::Tensor & self);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_mv(const at::Tensor & self, const at::Tensor & vec);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);
+TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, const c10::optional<at::Tensor> & save_mean, const c10::optional<at::Tensor> & save_invstd, bool train, double eps, ::std::array<bool,3> output_mask);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional<bool> train);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale);
 TORCH_API std::vector<torch::lazy::Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
```

torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp

Lines changed: 0 additions & 97 deletions
This file was deleted.

torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h

Lines changed: 0 additions & 156 deletions
This file was deleted.
