Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions test/test_ops_error_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,21 @@ def test():
expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2)."""
)

def test_uniform__raises_error_on_invalid_range(self):
device = torch_xla.device()
a = torch.empty(5, 5, device=device)
from_ = 5.
to_ = 2.

def test():
return a.uniform_(from_, to_)

self.assertExpectedRaisesInline(
exc_type=RuntimeError,
callable=test,
expect="""uniform_(): expected `from` (5) to be smaller or equal `to` (2)."""
)


if __name__ == "__main__":
unittest.main()
5 changes: 3 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3916,8 +3916,9 @@ at::Tensor& XLANativeFunctions::uniform_(
return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call(
self, from, to, generator);
}
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
tensor_methods::uniform_(xla_self, from, to);
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
bridge::GetXlaTensor(self));
XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to));
return self;
}

Expand Down
20 changes: 15 additions & 5 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,15 @@ absl::Status CheckStackAtLeastOneTensor(
return absl::OkStatus();
}

absl::Status CheckUniformRangeIsValid(double from, double to) {
if (from > to) {
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
absl::StrCat("uniform_(): expected `from` (", from,
") to be smaller or equal `to` (", to, ").")));
}
return absl::OkStatus();
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -3727,15 +3736,16 @@ std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim) {
return slices;
}

void uniform_(XLATensorPtr& input, double from, double to) {
XLA_CHECK_LE(from, to);
auto input_shape = input->shape();
absl::Status uniform_(XLATensorPtr& input, double from, double to) {
XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to));
xla::Shape input_shape = input->shape();
input->SetInPlaceIrValue(torch_xla::MakeNode<Uniform>(
XLAGraphExecutor::Get()->GetIrValueForScalar(
from, input_shape.get().element_type(), input->GetDevice()),
from, input_shape.element_type(), input->GetDevice()),
XLAGraphExecutor::Get()->GetIrValueForScalar(
to, input_shape.get().element_type(), input->GetDevice()),
to, input_shape.element_type(), input->GetDevice()),
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
return absl::OkStatus();
}

XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> triangular_solve(
// removed.
std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim);

void uniform_(XLATensorPtr& input, double from, double to);
absl::Status uniform_(XLATensorPtr& input, double from, double to);

// Insert a dimension of size one at the specified position.
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);
Expand Down
Loading