Commit 1ed6b46

Error Handling: propagate status for ReleaseGilAndTransferData and XlaDataToTensors. (#9431)
1 parent 531c724 commit 1ed6b46

7 files changed: 21 additions, 13 deletions

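All seven files apply the same error-handling pattern: functions that can fail now return absl::StatusOr<T>, intermediate callers propagate failures with XLA_ASSIGN_OR_RETURN, and callers whose signatures cannot carry a status unwrap with GetValueOrThrow. Below is a minimal sketch of the propagation half using plain absl; Transfer and SumTransferred are hypothetical stand-ins, and the real XLA_ASSIGN_OR_RETURN macro (from torch_xla/csrc/status.h) is assumed to expand to roughly the check-and-return shown.

#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical leaf function standing in for TransferFromDevice: it can
// fail, so it returns absl::StatusOr<T> instead of T.
absl::StatusOr<std::vector<int>> Transfer(bool ok) {
  if (!ok) return absl::InternalError("transfer failed");
  return std::vector<int>{1, 2, 3};
}

// Hypothetical intermediate function: it propagates failure to its caller
// instead of crashing, mirroring what ReleaseGilAndTransferData now does.
absl::StatusOr<int> SumTransferred(bool ok) {
  // XLA_ASSIGN_OR_RETURN is assumed to expand to roughly this
  // check-and-unwrap sequence.
  absl::StatusOr<std::vector<int>> data = Transfer(ok);
  if (!data.ok()) return data.status();
  int sum = 0;
  for (int v : *data) sum += v;
  return sum;
}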

test/cpp/test_xla_sharding.cpp

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
                         torch::lazy::BackendDataPtr b,
                         at::ScalarType element_type) {
   std::vector<at::Tensor> tensors =
-      XlaDataToTensors({a, b}, {element_type, element_type});
+      GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type}));
   return TensorCompare(tensors[0], tensors[1]);
 }
 }  // namespace

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -2712,7 +2712,7 @@ void InitXlaModuleBindings(py::module m) {
         }
 
         std::vector<at::Tensor> cpu_shards =
-            XlaDataToTensors(WrapXlaData(handles), element_types);
+            GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types));
         // Populate the resulting vector of shards and device strings
         std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
         int shards_per_tensor =

torch_xla/csrc/tensor.cpp

Lines changed: 2 additions & 1 deletion
@@ -40,6 +40,7 @@
 #include "torch_xla/csrc/runtime/pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -512,7 +513,7 @@ at::Tensor XLATensor::ToTensor(bool detached) {
     // The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
     // XlaNode is available on the tensor.
     std::vector<at::Tensor> tensors =
-        XlaDataToTensors({GetXlaData()}, {dtype()});
+        GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()}));
     tensor = std::move(tensors.front());
     if (!detached) {
       SetTensorData(tensor);

torch_xla/csrc/tensor_util.cpp

Lines changed: 10 additions & 6 deletions
@@ -896,7 +896,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
   return literal;
 }
 
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data) {
   // HACK: This method may be called outside of python (mainly in C++ tests) or
   // when the GIL is already released, so we must check both cases here. If
@@ -909,20 +909,24 @@ std::vector<xla::Literal> ReleaseGilAndTransferData(
   if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
     save = PyEval_SaveThread();
   }
-  std::vector<xla::Literal> literals =
-      GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice(
-          UnwrapXlaData(xla_data)));
+
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client,
+                       runtime::GetComputationClient());
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       client->TransferFromDevice(UnwrapXlaData(xla_data)));
+
   if (save) {
     PyEval_RestoreThread(save);
   }
 
   return literals;
 }
 
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type) {
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(xla_data);
+  XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
+                       ReleaseGilAndTransferData(xla_data));
   std::vector<at::Tensor> tensors(literals.size());
   absl::BlockingCounter counter(literals.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
torch_xla/csrc/tensor_util.h

Lines changed: 2 additions & 2 deletions
@@ -28,11 +28,11 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
 // Execution and data transfer are async in PJRT, so TransferFromDevice may
 // block until `DataPtr`s are ready. Release the GIL so other threads can
 // proceed and unblock any transfers or collective computations.
-std::vector<xla::Literal> ReleaseGilAndTransferData(
+absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data);
 
 // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
-std::vector<at::Tensor> XlaDataToTensors(
+absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
     absl::Span<const torch::lazy::BackendDataPtr> xla_data,
     absl::Span<const at::ScalarType> dest_element_type);
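With the declarations above, callers pick between the two patterns this commit uses. A hedged sketch follows; FirstTensor and FirstTensorOrThrow are hypothetical helpers, and the includes follow this diff's paths.

#include "torch_xla/csrc/status.h"       // XLA_ASSIGN_OR_RETURN, GetValueOrThrow
#include "torch_xla/csrc/tensor_util.h"  // XlaDataToTensors

// Inside a function that itself returns absl::StatusOr<T>, propagate the
// error with the macro.
absl::StatusOr<at::Tensor> FirstTensor(
    absl::Span<const torch::lazy::BackendDataPtr> data,
    absl::Span<const at::ScalarType> types) {
  XLA_ASSIGN_OR_RETURN(std::vector<at::Tensor> tensors,
                       XlaDataToTensors(data, types));
  return tensors.at(0);
}

// At a boundary whose signature cannot carry a status (e.g. the
// BackendImplInterface override in xla_backend_impl.cpp), unwrap and throw.
at::Tensor FirstTensorOrThrow(
    absl::Span<const torch::lazy::BackendDataPtr> data,
    absl::Span<const at::ScalarType> types) {
  return GetValueOrThrow(XlaDataToTensors(data, types)).at(0);
}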

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 1 deletion
@@ -10,6 +10,8 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
+#include "torch_xla/csrc/tensor_util.h"
 
 namespace at {
 // This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -92,7 +94,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
       const torch::lazy::BackendDataPtr data,
       std::optional<at::ScalarType> logical_scalar_type) const override {
     // TODO(JackCaoG): handle the logical_scalar_type == nullptr case
-    return XlaDataToTensors({data}, {*logical_scalar_type})[0];
+    return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0];
   }
 
   std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 2 additions & 1 deletion
@@ -497,7 +497,8 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
 
-  std::vector<xla::Literal> literals = ReleaseGilAndTransferData(tensors_data);
+  std::vector<xla::Literal> literals =
+      GetValueOrThrow(ReleaseGilAndTransferData(tensors_data));
 
   return FetchTensors(tensors, literals,
                       async != nullptr ? &async->indices : nullptr);
