
Commit 34e8eb2

lezcano authored and pytorchmergebot committed
Dispatch the auxiliary frobenius_norm and nuclear_norm to better implementations and deprecate them (pytorch#81763)
These functions are now legacy functions. We deprecate them, and we take this chance to dispatch them to more efficient and consistent implementations. Doing so should make it easier to write a conversion rule for them so that they can eventually be removed for good.

Differential Revision: [D42354776](https://our.internmc.facebook.com/intern/diff/D42354776)

Pull Request resolved: pytorch#81763
Approved by: https://github.com/ngimel
1 parent 1af40d5 commit 34e8eb2
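For context, here is a minimal sketch (not part of the commit) of the migration that the new deprecation warnings point users to; it assumes only the public torch.norm and torch.linalg APIs:

```python
import torch

A = torch.randn(3, 4, 5)

# Deprecated Frobenius-norm path, reached through torch.norm with p='fro'.
fro_old = torch.norm(A, p="fro", dim=(-2, -1), keepdim=False)
# Replacement suggested by the new warning: `linalg.vector_norm(A, 2., dim, keepdim)`.
fro_new = torch.linalg.vector_norm(A, 2, dim=(-2, -1), keepdim=False)

# Deprecated nuclear-norm path, reached through torch.norm with p='nuc'.
nuc_old = torch.norm(A, p="nuc", dim=(-2, -1), keepdim=False)
# Replacement suggested by the new warning: `linalg.matrix_norm(A, 'nuc', dim, keepdim)`.
nuc_new = torch.linalg.matrix_norm(A, "nuc", dim=(-2, -1), keepdim=False)
```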

3 files changed: +62 -88 lines

aten/src/ATen/native/LinearAlgebra.cpp
+51 -61
@@ -2771,94 +2771,84 @@ Tensor& linalg_norm_out(const Tensor& X, c10::string_view ord, OptionalIntArrayR
 
 ////////////////////////////////////////////////////////////////////////////////
 //                             Frobenius Norm                                 //
-//             Just used in torch.norm. It should not be removed.             //
 ////////////////////////////////////////////////////////////////////////////////
 
 Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
-  TORCH_CHECK(
-      dim.size() <= 2,
-      "Expected at most 2 dimensions, but got ",
-      dim.size(),
-      " dimensions instead.");
-  Tensor result;
-  if (dim.size() == 1 || dim.size() == 0) {
-    result = at::norm(self, 2, dim, keepdim);
-  } else {
-    auto dim_ = dim.vec();
-    maybe_wrap_dims(dim_, self.dim());
-    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
-    if (self.is_complex()) {
-      result = at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
-    } else {
-      result = at::sqrt(at::sum((self * self), dim_, keepdim));
-    }
+  auto device = self.device();
+  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
+    TORCH_WARN_ONCE(
+      "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
+      "It will be removed in a future PyTorch release. Please use ",
+      "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
+    );
   }
-  TORCH_INTERNAL_ASSERT(result.scalar_type() == toRealValueType(self.scalar_type()));
-  TORCH_INTERNAL_ASSERT(result.layout() == c10::Layout::Strided);
-  return result;
+  // This frobenius norm is just wrong, but well
+  TORCH_CHECK(dim.size() <= 2,
+              "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
+  // Dispatch to at::norm as it is implemented for Sparse and MPS backends
+  // TODO Make the backends implement vector_norm and matrix_norm
+  return at::norm(self, 2., dim, keepdim);
 }
 
 Tensor &frobenius_norm_out(const Tensor& self,
     IntArrayRef dim,
     bool keepdim,
     Tensor& result) {
-  auto result_ = at::native::frobenius_norm(self, dim, keepdim);
-  // NOTE: It would be better to avoid resize and copy by using norm_out and sqrt_out in frobenius_norm.
-  // However, norm_out and sqrt_out do not support automatic differentiation.
-  // More details here: https://github.com/pytorch/pytorch/pull/44095#discussion_r486673947
-  at::native::resize_output(result, result_.sizes());
-  result.copy_(result_);
-  return result;
+  auto device = self.device();
+  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
+    TORCH_WARN_ONCE(
+      "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
+      "It will be removed in a future PyTorch release. Please use ",
+      "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
+    );
+  }
+  TORCH_CHECK(dim.size() <= 2,
+              "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
+  return at::norm_out(result, self, 2., dim, keepdim);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 //                              Nuclear Norm                                  //
-//             Just used in torch.norm. It should not be removed.             //
 ////////////////////////////////////////////////////////////////////////////////
 
 Tensor nuclear_norm(const Tensor& self, bool keepdim) {
-  TORCH_CHECK(
-      self.dim() == 2,
-      "Expected a tensor with 2 dimensions, but got a tensor with ",
-      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::native::nuclear_norm(self, IntArrayRef({0, 1}), keepdim);
+  return at::native::nuclear_norm(self, IntArrayRef({-2, -1}), keepdim);
 }
 
 Tensor &nuclear_norm_out(const Tensor& self, bool keepdim, Tensor& result) {
-  TORCH_CHECK(
-      self.dim() == 2,
-      "Expected a tensor with 2 dimensions, but got a tensor with ",
-      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::native::nuclear_norm_out(self, IntArrayRef({0, 1}), keepdim, result);
-}
-
-namespace {
-Tensor nuclear_norm_impl(const Tensor& self, IntArrayRef dim, bool keepdim) {
-  TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
-  auto dim_ = dim.vec();
-  maybe_wrap_dims(dim_, self.dim());
-
-  auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
-  Tensor p = self.permute(permutation);
-  Tensor result_ = at::sum(at::linalg_svdvals(p), -1, keepdim);
-  if (keepdim) {
-    result_.unsqueeze_(-1);
-    auto permutation_reverse = create_reverse_permutation(std::move(permutation));
-    result_ = result_.permute(permutation_reverse);
+  auto device = self.device();
+  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
+    TORCH_WARN_ONCE(
+      "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
+      "It will be removed in a future PyTorch release. Please use ",
+      "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
+    );
   }
-  return result_;
+  return at::linalg_matrix_norm_out(result, self, "nuc", IntArrayRef({-2, -1}), keepdim);
 }
-} // anonymous namespace
 
 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
-  return nuclear_norm_impl(self, dim, keepdim).to(toRealValueType(self.scalar_type()));
+  auto device = self.device();
+  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
+    TORCH_WARN_ONCE(
+      "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
+      "It will be removed in a future PyTorch release. Please use ",
+      "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
+    );
+  }
+  return at::linalg_matrix_norm(self, "nuc", dim, keepdim);
 }
 
 Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
-  auto result_ = nuclear_norm_impl(self, dim, keepdim);
-  at::native::resize_output(result, result_.sizes());
-  result.copy_(result_);
-  return result;
+  auto device = self.device();
+  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
+    TORCH_WARN_ONCE(
+      "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
+      "It will be removed in a future PyTorch release. Please use ",
+      "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
+    );
+  }
+  return at::linalg_matrix_norm_out(result, self, "nuc", dim, keepdim);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
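The rewritten functions above have the one-argument nuclear_norm overload pass IntArrayRef({-2, -1}) rather than checking for an exactly 2-D input and passing {0, 1}, and they forward the actual computation to at::norm and at::linalg_matrix_norm. A small sanity-check sketch (not part of the commit, using only public Python APIs) of the equivalences the new dispatch relies on:

```python
import torch

A = torch.randn(2, 3, 4)

# frobenius_norm now dispatches to at::norm(self, 2., dim, keepdim),
# which is the vector 2-norm over the reduced dimensions.
torch.testing.assert_close(
    torch.norm(A, p="fro", dim=(1, 2)),
    torch.linalg.vector_norm(A, 2, dim=(1, 2)),
)

# nuclear_norm (dim overload) now dispatches to linalg.matrix_norm(self, 'nuc', dim, keepdim),
# i.e. the sum of singular values of each matrix slice.
torch.testing.assert_close(
    torch.norm(A, p="nuc", dim=(1, 2)),
    torch.linalg.matrix_norm(A, "nuc", dim=(1, 2)),
)
```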

aten/src/ATen/native/ReduceOps.cpp
+4 -27
@@ -66,6 +66,7 @@
 #include <ATen/ops/gradient_native.h>
 #include <ATen/ops/imag.h>
 #include <ATen/ops/isnan_native.h>
+#include <ATen/ops/linalg_vector_norm.h>
 #include <ATen/ops/logcumsumexp.h>
 #include <ATen/ops/logcumsumexp_native.h>
 #include <ATen/ops/logical_xor.h>
@@ -1451,34 +1452,10 @@ void impl_func_norm(
     bool keepdim,
     optional<ScalarType> opt_dtype,
     const Tensor& result) {
+  // Left this implementation without deprecating it as it is called in a number of places
+  // in the codebase. We should swap those by linalg_vector_norm
   auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
-  auto in_dtype = opt_dtype.value_or(self.scalar_type());
-  auto out_dtype = result.scalar_type();
-
-  // See the note [Reductions do not use vectorized ops]
-  Tensor self_;
-  if (self.is_cpu() && self.is_complex() && std::abs(p.toDouble()) == INFINITY) {
-    if (opt_dtype.has_value()) {
-      self_ = self.to(*opt_dtype).abs();
-    } else {
-      self_ = self.abs();
-    }
-  } else {
-    self_ = self;
-  }
-
-
-  // omit in_dtype in the following call, to avoid make_reduction explicitly
-  // casting input to out_dtype
-  auto iter = isComplexType(self_.scalar_type())
-      ? meta::make_reduction(self_, result, dim, keepdim, in_dtype)
-      : meta::make_reduction_from_out_ty(self_, result, dim, keepdim, out_dtype);
-
-  if (iter.numel() == 0) {
-    result.zero_();
-  } else {
-    norm_stub(iter.device_type(), iter, p);
-  }
+  at::linalg_vector_norm_out(const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
 }
 
 TORCH_IMPL_FUNC(norm_out)
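Since impl_func_norm now forwards to at::linalg_vector_norm_out, torch.norm's generic p-norm path and torch.linalg.vector_norm are expected to agree. A quick sketch (not part of the commit):

```python
import torch

x = torch.randn(4, 5)

# torch.norm(x, p, dim) is now computed by the same code path as linalg.vector_norm.
for p in (0.5, 1, 2, 3, float("inf")):
    torch.testing.assert_close(
        torch.norm(x, p=p, dim=1),
        torch.linalg.vector_norm(x, p, dim=1),
    )
```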

torch/testing/_internal/common_methods_invocations.py
+7
@@ -16492,6 +16492,13 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            check_batched_forward_grad=False,
            supports_fwgrad_bwgrad=True,
            skips=(
+               # MPS has some mild accuracy issues for float16. We divide the tolerances by 10
+               DecorateInfo(
+                   toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}),
+                   'TestConsistency',
+                   'test_output_match',
+
+               ),
                # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
                DecorateInfo(
                    unittest.skip("Skipped!"),
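For reference, toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}) changes the atol/rtol that the decorated TestConsistency.test_output_match comparison uses for float16 results. In torch.testing.assert_close terms, the closeness check it configures is |actual - expected| <= atol + rtol * |expected|; a sketch with made-up values (not part of the commit):

```python
import torch

# Hypothetical float16 result vs. its reference value (made-up numbers).
actual = torch.tensor([1.005], dtype=torch.float16)
expected = torch.tensor([1.000], dtype=torch.float16)

# Closeness under the overridden tolerances:
#   |actual - expected| <= atol + rtol * |expected|  ->  ~0.005 <= 1e-4 + 0.01 * 1.0
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=0.01)
```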
