Skip to content

Commit 2820f7c

Browse files
authored
Error Handling: return status value when loading PjRt dynamic plugin. (#9495)
1 parent 29ae4c7 commit 2820f7c

File tree

4 files changed

+60
-11
lines changed

4 files changed

+60
-11
lines changed

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ function run_xla_op_tests2 {
227227
run_test "$_TEST_DIR/test_assume_pure_spmd.py"
228228
run_test "$_TEST_DIR/test_assume_pure_torch.py"
229229
run_test "$_TEST_DIR/test_dynamic_shapes_detector.py"
230+
run_test "$_TEST_DIR/test_runtime_client_initialization_error.py"
230231
}
231232

232233
function run_xla_op_tests3 {
test/test_runtime_client_initialization_error.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
import torch_xla
3+
import torch_xla.core.xla_env_vars as xenv
4+
import unittest
5+
6+
7+
class TestClientInitializationError(unittest.TestCase):
8+
9+
def test(self):
10+
11+
def initialize_client(device):
12+
os.environ[xenv.PJRT_DEVICE] = device
13+
14+
# The expected message is the same on every call: the result of the
15+
# first call (made with DUMMY_DEVICE) is cached, so all later calls
16+
# still have "DUMMY_DEVICE" in their message.
17+
message = (
18+
f"No PjRtPlugin registered for: DUMMY_DEVICE. "
19+
f"Make sure the environment variable {xenv.PJRT_DEVICE} is set "
20+
"to a correct device name.")
21+
22+
with self.assertRaisesRegex(RuntimeError, expected_regex=message):
23+
torch_xla._XLAC._init_computation_client()
24+
25+
# Run the initialization function for the first time, which ends with
26+
# an exception being thrown.
27+
initialize_client("DUMMY_DEVICE")
28+
29+
# Even if the device exists, this call should fail, since the result
30+
# of the first call is cached.
31+
initialize_client("CPU")
32+
33+
34+
if __name__ == '__main__':
35+
unittest.main()

torch_xla/csrc/runtime/pjrt_registry.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,21 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
6161
return allocator_config;
6262
}
6363

64-
std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
64+
absl::StatusOr<std::shared_ptr<const PjRtPlugin>> GetPjRtPlugin(
6565
const std::string& device_type) {
66-
auto plugin_path = pjrt_plugins_.find(device_type);
67-
return plugin_path != pjrt_plugins_.end() ? plugin_path->second : nullptr;
66+
auto entry = pjrt_plugins_.find(device_type);
67+
if (entry == pjrt_plugins_.end()) {
68+
std::string message = absl::StrCat(
69+
"No PjRtPlugin registered for: ", device_type,
70+
". Make sure the environment variable ", env::kEnvPjRtDevice,
71+
" is set to a correct device name. See "
72+
"https://github.com/pytorch/xla/blob/master/docs/source/"
73+
"contribute/plugins.md for more information on "
74+
"implementing and registering a new "
75+
"plugin.");
76+
return XLA_ERROR_WITH_LOCATION(absl::FailedPreconditionError(message));
77+
}
78+
return entry->second;
6879
}
6980

7081
} // namespace
@@ -83,10 +94,10 @@ InitializePjRt(const std::string& device_type) {
8394

8495
if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false) &&
8596
device_type != "CPU") {
86-
std::shared_ptr<const PjRtPlugin> plugin = GetPjRtPlugin(device_type);
97+
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
98+
XLA_ASSIGN_OR_RETURN(std::shared_ptr<const PjRtPlugin> plugin,
99+
GetPjRtPlugin(device_type));
87100
if (plugin) {
88-
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
89-
90101
// Init the absl logging to avoid the log spam.
91102
absl::InitializeLog();
92103

torch_xla/csrc/runtime/runtime.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ static std::atomic<bool> g_computation_client_initialized(false);
1919
// Can only be called when g_computation_client_initialized is false.
2020
static absl::StatusOr<ComputationClient * absl_nonnull>
2121
InitializeComputationClient() {
22-
ABSL_CHECK(!g_computation_client_initialized)
23-
<< "InitializeComputationClient() can only be called once.";
24-
g_computation_client_initialized = true;
25-
2622
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
2723
tsl::testing::InstallStacktraceHandler();
2824
}
@@ -37,13 +33,19 @@ InitializeComputationClient() {
3733
absl::FailedPreconditionError("$PJRT_DEVICE is not set."));
3834
}
3935

36+
ABSL_CHECK(!g_computation_client_initialized)
37+
<< "ComputationClient can only be initialized once.";
38+
4039
std::unique_ptr<ComputationClient> client;
4140
if (use_ifrt) {
4241
XLA_ASSIGN_OR_RETURN(client, IfrtComputationClient::Create());
4342
} else {
4443
XLA_ASSIGN_OR_RETURN(client, PjRtComputationClient::Create());
4544
}
4645

46+
// Set only if we actually successfully initialized a client.
47+
g_computation_client_initialized = true;
48+
4749
return client.release();
4850
}
4951

@@ -59,7 +61,7 @@ const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {
5961
}
6062

6163
ComputationClient* absl_nonnull GetComputationClientOrDie() {
62-
return GetComputationClient().value();
64+
return GetValueOrThrow(GetComputationClient());
6365
}
6466

6567
ComputationClient* GetComputationClientIfInitialized() {

0 commit comments

Comments
 (0)