Commit 071ddfa

Fix XLA type of the view_as_real (#8370)
1 parent c4f2771 commit 071ddfa
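
In short: view_as_real on an XLA tensor previously forced the result's logical dtype to Float and kept the input's complex element type in the declared XLA shape; this commit derives both from the input's precision instead. A minimal repro sketch, assuming a working torch_xla build with functionalization enabled (it mirrors the tests added below):

import torch
import torch_xla

x = torch.randn(4, dtype=torch.cdouble, device=torch_xla.device())
real = torch.view_as_real(x)

# With this fix, a complex128 input yields torch.float64 at the torch level
# and an f64[4,2] XLA type; previously the dtype came out as float32.
assert real.dtype == torch.float64
assert "f64[4,2]" in torch_xla._XLAC._get_xla_tensor_debug_info(real)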

File tree

3 files changed: 41 additions, 4 deletions


test/test_operations.py

Lines changed: 24 additions & 0 deletions
@@ -793,6 +793,30 @@ def test_sgn(self):
     xla_a = t.to(xla_device).sgn()
     self.assertEqual(a.data, xla_a.data.cpu())
 
+  @skipIfFunctionalizationDisabled("view_as_real unsupported")
+  def test_view_as_real_c64(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, dtype=torch.cfloat, device=xla_device)
+    real = torch.view_as_real(x)
+    self.assertEqual(real.dtype, torch.float32)
+    # The XLA type of the real tensor needs to be f32 as well.
+    self.assertIn("f32[4,2]", torch_xla._XLAC._get_xla_tensor_debug_info(real))
+    # The generated HLO needs to have type f32 as well.
+    self.assertIn("f32[4,2]",
+                  torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])
+
+  @skipIfFunctionalizationDisabled("view_as_real unsupported")
+  def test_view_as_real_c128(self):
+    xla_device = torch_xla.device()
+    x = torch.randn(4, dtype=torch.cdouble, device=xla_device)
+    real = torch.view_as_real(x)
+    self.assertEqual(real.dtype, torch.float64)
+    # The XLA type of the real tensor needs to be f64 as well.
+    self.assertIn("f64[4,2]", torch_xla._XLAC._get_xla_tensor_debug_info(real))
+    # The generated HLO needs to have type f64 as well.
+    self.assertIn("f64[4,2]",
+                  torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])
+
   def test_index_put(self):
     xla_device = xm.xla_device()
     a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
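
For reference, these assertions encode the stock view_as_real semantics that the XLA lowering has to match: each complex element becomes a (real, imag) pair, so the trailing dimension gains size 2 and the dtype drops to the corresponding real type. A quick CPU-only illustration (no torch_xla required):

import torch

x64 = torch.randn(4, dtype=torch.cfloat)    # complex64
x128 = torch.randn(4, dtype=torch.cdouble)  # complex128
assert torch.view_as_real(x64).dtype == torch.float32
assert torch.view_as_real(x128).dtype == torch.float64
assert torch.view_as_real(x64).shape == (4, 2)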

torch_xla/csrc/ops/ops.cpp

Lines changed: 16 additions & 3 deletions
@@ -998,11 +998,24 @@ torch::lazy::NodePtr ViewAsRealCopy(const torch::lazy::Value& input) {
     return node.ReturnOp(BuildStack({real, imag}, input_shape.rank()), loctx);
   };
 
-  xla::Shape result_shape = GetXlaShape(input);
-  result_shape.add_dimensions(2);
+  xla::Shape input_shape = GetXlaShape(input);
+  xla::Shape res_shape;
+  switch (input_shape.element_type()) {
+    case xla::PrimitiveType::C64:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
+                                            input_shape.dimensions());
+      break;
+    case xla::PrimitiveType::C128:
+      res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F64,
+                                            input_shape.dimensions());
+      break;
+    default:
+      XLA_ERROR() << "input shape type not supported: " << input_shape;
+  }
+  res_shape.add_dimensions(2);
 
   return GenericOp(torch::lazy::OpKind(at::aten::view_as_real_copy), {input},
-                   result_shape, std::move(lower_fn));
+                   res_shape, std::move(lower_fn));
 }
 
 torch::lazy::NodePtr Rsub(const torch::lazy::Value& input,
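
The lowering above stacks the real and imaginary components of the input, whose element type is already real, so for a c64[4] input the emitted HLO is f32[4,2]. The removed lines copied the input shape wholesale, keeping the complex element type (C64/C128) in the declared IR shape, which therefore disagreed with the lowered computation; the switch derives the matching real element type instead. A quick check of the mapping from Python, assuming torch_xla is installed:

import torch
import torch_xla

device = torch_xla.device()
for ctype, expected in [(torch.cfloat, "f32[4,2]"),
                        (torch.cdouble, "f64[4,2]")]:
  real = torch.view_as_real(torch.randn(4, dtype=ctype, device=device))
  assert expected in torch_xla._XLAC._get_xla_tensor_debug_info(real)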

torch_xla/csrc/tensor_methods.cpp

Lines changed: 1 addition & 1 deletion
@@ -3567,7 +3567,7 @@ XLATensorPtr view_as_complex_copy(const XLATensorPtr& input) {
 
 XLATensorPtr view_as_real_copy(const XLATensorPtr& input) {
   return input->CreateFrom(ViewAsRealCopy(input->GetIrValue()),
-                           at::ScalarType::Float);
+                           /*logical_element_type=*/std::nullopt);
 }
 
 XLATensorPtr var(const XLATensorPtr& input, std::vector<int64_t> dimensions,
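
With std::nullopt for logical_element_type, CreateFrom presumably derives the torch-level dtype from the IR value's XLA element type (F32 or F64, per the ops.cpp change above) instead of hard-coding Float, which previously made the real view of a complex128 tensor report float32. A minimal dtype check, assuming torch_xla is installed:

import torch
import torch_xla

device = torch_xla.device()
# complex64 still maps to float32, as before...
assert torch.view_as_real(
    torch.randn(4, dtype=torch.cfloat, device=device)).dtype == torch.float32
# ...while complex128 now maps to float64 rather than being forced to Float.
assert torch.view_as_real(
    torch.randn(4, dtype=torch.cdouble, device=device)).dtype == torch.float64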
