diff --git a/test/test_operations.py b/test/test_operations.py
index 545831acf49..fc6765d0483 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -793,6 +793,30 @@ def test_sgn(self):
     xla_a = t.to(xla_device).sgn()
     self.assertEqual(a.data, xla_a.data.cpu())
 
+  @skipIfFunctionalizationDisabled("view_as_real unsupported")
+  def test_view_as_real_c64(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, dtype=torch.cfloat, device=xla_device)
+    real = torch.view_as_real(x)
+    self.assertEqual(real.dtype, torch.float32)
+    # The XLA type of the real view needs to be f32 as well.
+    self.assertIn("f32[4,2]", torch_xla._XLAC._get_xla_tensor_debug_info(real))
+    # The generated HLO needs to have type f32 as well.
+    self.assertIn("f32[4,2]",
+                  torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])
+
+  @skipIfFunctionalizationDisabled("view_as_real unsupported")
+  def test_view_as_real_c128(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, dtype=torch.cdouble, device=xla_device)
+    real = torch.view_as_real(x)
+    self.assertEqual(real.dtype, torch.float64)
+    # The XLA type of the real view needs to be f64 as well.
+    self.assertIn("f64[4,2]", torch_xla._XLAC._get_xla_tensor_debug_info(real))
+    # The generated HLO needs to have type f64 as well.
+    self.assertIn("f64[4,2]",
+                  torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])
+
   def test_index_put(self):
     xla_device = xm.xla_device()
     a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index c8d46edc311..84a3f623200 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -998,11 +998,24 @@ torch::lazy::NodePtr ViewAsRealCopy(const torch::lazy::Value& input) {
     return node.ReturnOp(BuildStack({real, imag}, input_shape.rank()), loctx);
   };
 
-  xla::Shape result_shape = GetXlaShape(input);
-  result_shape.add_dimensions(2);
+  xla::Shape input_shape = GetXlaShape(input);
+  xla::Shape res_shape;
+  switch (input_shape.element_type()) {
+    case xla::PrimitiveType::C64:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
+                                            input_shape.dimensions());
+      break;
+    case xla::PrimitiveType::C128:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F64,
+                                            input_shape.dimensions());
+      break;
+    default:
+      XLA_ERROR() << "Input element type not supported: " << input_shape;
+  }
+  res_shape.add_dimensions(2);
 
   return GenericOp(torch::lazy::OpKind(at::aten::view_as_real_copy), {input},
-                   result_shape, std::move(lower_fn));
+                   res_shape, std::move(lower_fn));
 }
 
 torch::lazy::NodePtr Rsub(const torch::lazy::Value& input,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 9c34f84c44a..f73b5ebb3cf 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -3567,7 +3567,7 @@ XLATensorPtr view_as_complex_copy(const XLATensorPtr& input) {
 
 XLATensorPtr view_as_real_copy(const XLATensorPtr& input) {
   return input->CreateFrom(ViewAsRealCopy(input->GetIrValue()),
-                           at::ScalarType::Float);
+                           /*logical_element_type=*/std::nullopt);
 }
 
 XLATensorPtr var(const XLATensorPtr& input, std::vector<int64_t> dimensions,
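
Note (not part of the patch): a minimal CPU-only sketch of the dtype contract this patch makes the XLA lowering honor, assuming stock PyTorch. It illustrates why hardcoding at::ScalarType::Float was wrong for complex128 inputs.

    import torch

    # view_as_real reinterprets a complex tensor as a real one with a trailing
    # dimension of size 2 (real and imaginary parts). complex64 maps to float32
    # and complex128 maps to float64, matching the f32[4,2] / f64[4,2] XLA
    # shapes asserted in the tests above.
    x64 = torch.randn(4, dtype=torch.cfloat)
    assert torch.view_as_real(x64).dtype == torch.float32
    assert torch.view_as_real(x64).shape == (4, 2)

    x128 = torch.randn(4, dtype=torch.cdouble)
    assert torch.view_as_real(x128).dtype == torch.float64
    assert torch.view_as_real(x128).shape == (4, 2)

Passing std::nullopt as the logical element type lets the tensor's dtype follow the IR node's XLA shape, which the new switch in ViewAsRealCopy now derives from the complex input's element type.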