Skip to content

Commit

Permalink
Fix XLA type of the view_as_real (#8370)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored Nov 12, 2024
1 parent c4f2771 commit 071ddfa
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 4 deletions.
24 changes: 24 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,30 @@ def test_sgn(self):
xla_a = t.to(xla_device).sgn()
self.assertEqual(a.data, xla_a.data.cpu())

@skipIfFunctionalizationDisabled("view_as_real unsupported")
def test_view_as_real_c64(self):
xla_device = torch_xla.device()
x = torch.randn(4, dtype=torch.cfloat, device=xla_device)
real = torch.view_as_real(x)
self.assertEqual(real.dtype, torch.float32)
# XLA type of the real needs to be f32 as well
self.assertIn("f32[4,2]", torch_xla._XLAC._get_xla_tensor_debug_info(real))
# HLO generated needs to have type f32 as well
self.assertIn("f32[4,2]",
torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])

@skipIfFunctionalizationDisabled("view_as_real unsupported")
def test_view_as_real_c128(self):
xla_device = torch_xla.device()
x = torch.randn(4, dtype=torch.cdouble, device=xla_device)
real = torch.view_as_real(x)
self.assertEqual(real.dtype, torch.float64)
# XLA type of the real needs to be f32 as well
self.assertIn("f64[4,2]", torch_xla._XLAC._get_xla_tensor_debug_info(real))
# HLO generated needs to have type f32 as well
self.assertIn("f64[4,2]",
torch_xla._XLAC._get_xla_tensors_text([real]).split('\n')[-3])

def test_index_put(self):
xla_device = xm.xla_device()
a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
Expand Down
19 changes: 16 additions & 3 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,11 +998,24 @@ torch::lazy::NodePtr ViewAsRealCopy(const torch::lazy::Value& input) {
return node.ReturnOp(BuildStack({real, imag}, input_shape.rank()), loctx);
};

xla::Shape result_shape = GetXlaShape(input);
result_shape.add_dimensions(2);
xla::Shape input_shape = GetXlaShape(input);
xla::Shape res_shape;
switch (input_shape.element_type()) {
case xla::PrimitiveType::C64:
res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
input_shape.dimensions());
break;
case xla::PrimitiveType::C128:
res_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F64,
input_shape.dimensions());
break;
default:
XLA_ERROR() << "input shape type not supported: " << input_shape;
}
res_shape.add_dimensions(2);

return GenericOp(torch::lazy::OpKind(at::aten::view_as_real_copy), {input},
result_shape, std::move(lower_fn));
res_shape, std::move(lower_fn));
}

torch::lazy::NodePtr Rsub(const torch::lazy::Value& input,
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3567,7 +3567,7 @@ XLATensorPtr view_as_complex_copy(const XLATensorPtr& input) {

XLATensorPtr view_as_real_copy(const XLATensorPtr& input) {
return input->CreateFrom(ViewAsRealCopy(input->GetIrValue()),
at::ScalarType::Float);
/*logical_element_type=*/std::nullopt);
}

XLATensorPtr var(const XLATensorPtr& input, std::vector<int64_t> dimensions,
Expand Down

0 comments on commit 071ddfa

Please sign in to comment.