Skip to content

Commit fc8dcd5

Browse files
committed
Refactor computation clients to use factory pattern
Update IFRT and PjRt computation clients to use `Create()` factory methods:
- Replace constructors with factory methods that return `StatusOr<T>`
- Use `ConsumeAndMaybeThrow` for `XlaCoordinator::Create` integration
- Improve error handling with proper status propagation
1 parent 40db714 commit fc8dcd5

8 files changed

+78
-15
lines changed

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ cc_library(
2424
":env_vars",
2525
":ifrt_computation_client",
2626
":pjrt_computation_client",
27+
"//torch_xla/csrc:status",
2728
"@com_google_absl//absl/log:absl_check",
2829
"@tsl//tsl/platform:stacktrace",
2930
],

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,13 @@ std::vector<std::string> IfrtComputationClient::IfrtDevicesToString(
120120
return strs;
121121
}
122122

123-
IfrtComputationClient::IfrtComputationClient() {
123+
IfrtComputationClient::IfrtComputationClient(PrivateUse) {}
124+
125+
absl::Status IfrtComputationClient::Initialize() {
124126
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
125127
std::unique_ptr<xla::PjRtClient> pjrt_client;
126-
std::tie(pjrt_client, coordinator_) =
127-
GetValueOrThrow(InitializePjRt(device_type));
128+
XLA_ASSIGN_OR_RETURN(std::tie(pjrt_client, coordinator_),
129+
InitializePjRt(device_type));
128130

129131
client_ = xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
130132

@@ -146,6 +148,15 @@ IfrtComputationClient::IfrtComputationClient() {
146148
auto tracked_devices = GetLocalDevices();
147149
tracked_devices.emplace_back(spmd_device_str);
148150
operation_manager_ = std::move(OperationManager(std::move(tracked_devices)));
151+
152+
return absl::OkStatus();
153+
}
154+
155+
absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
156+
IfrtComputationClient::Create() {
157+
auto ifrt_client = std::make_unique<IfrtComputationClient>(PrivateUse());
158+
XLA_RETURN_IF_ERROR(ifrt_client->Initialize());
159+
return std::move(ifrt_client);
149160
}
150161

151162
IfrtComputationClient::~IfrtComputationClient() {

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,17 @@ namespace torch_xla {
2727
namespace runtime {
2828

2929
class IfrtComputationClient : public ComputationClient {
30+
private:
31+
// Private struct for making the constructor private, but still callable
32+
// as: `std::make_unique<IfrtComputationClient>(PrivateUse())`.
33+
struct PrivateUse {
34+
// Constructor needs to be explicit for disallowing implicit construction
35+
// from `{}`.
36+
explicit PrivateUse() = default;
37+
};
38+
3039
public:
31-
IfrtComputationClient();
40+
IfrtComputationClient(PrivateUse);
3241
~IfrtComputationClient();
3342

3443
DataPtr CreateDataPlaceholder(
@@ -165,7 +174,15 @@ class IfrtComputationClient : public ComputationClient {
165174
XLA_ERROR() << __FUNCTION__ << " not implemented";
166175
}
167176

177+
// Creates a new instance of IfrtComputationClient and initializes it.
178+
static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
179+
Create();
180+
168181
private:
182+
// Convenience function called by `Create()` that initializes the current
183+
// IfrtComputationClient.
184+
absl::Status Initialize();
185+
169186
std::shared_ptr<xla::ifrt::PjRtClient> client_;
170187
std::unique_ptr<XlaCoordinator> coordinator_;
171188
// global_ordinals_ tracks a map from PjRtDeviceId to the device's

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include <string>
77
#include <vector>
88

9-
#include "absl/status/status.h"
109
#include "torch_xla/csrc/runtime/computation_client.h"
1110
#include "torch_xla/csrc/runtime/tensor_source.h"
11+
#include "torch_xla/csrc/status.h"
1212
#include "tsl/platform/env.h"
1313
#include "tsl/platform/logging.h"
1414
#include "tsl/platform/statusor.h"
@@ -36,7 +36,7 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
3636
TEST(PjRtComputationClientTest, Init) {
3737
// Get a CPU client.
3838
tsl::setenv("PJRT_DEVICE", "CPU", true);
39-
auto client = std::make_unique<IfrtComputationClient>();
39+
auto client = GetValueOrThrow(IfrtComputationClient::Create());
4040
std::string device = client->GetDefaultDevice();
4141

4242
// Compose a computation.

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,12 @@ std::vector<std::string> PjRtComputationClient::PjRtDevicesToString(
115115
return strs;
116116
}
117117

118-
PjRtComputationClient::PjRtComputationClient() {
118+
PjRtComputationClient::PjRtComputationClient(PrivateUse) {}
119+
120+
absl::Status PjRtComputationClient::Initialize() {
119121
std::string device_type = sys_util::GetEnvString(env::kEnvPjRtDevice, "");
120-
std::tie(client_, coordinator_) =
121-
GetValueOrThrow(InitializePjRt(device_type));
122+
XLA_ASSIGN_OR_RETURN(std::tie(client_, coordinator_),
123+
InitializePjRt(device_type));
122124

123125
// PjRtDevice IDs are not guaranteed to be dense, so we need to track
124126
// a device's global ordinal separately from its device ID. Order the
@@ -137,6 +139,15 @@ PjRtComputationClient::PjRtComputationClient() {
137139
auto tracked_devices = GetLocalDevices();
138140
tracked_devices.emplace_back(spmd_device_str);
139141
operation_manager_ = std::move(OperationManager(std::move(tracked_devices)));
142+
143+
return absl::OkStatus();
144+
}
145+
146+
absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
147+
PjRtComputationClient::Create() {
148+
auto pjrt_client = std::make_unique<PjRtComputationClient>(PrivateUse());
149+
XLA_RETURN_IF_ERROR(pjrt_client->Initialize());
150+
return std::move(pjrt_client);
140151
}
141152

142153
PjRtComputationClient::~PjRtComputationClient() {

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,17 @@ namespace torch_xla {
2525
namespace runtime {
2626

2727
class PjRtComputationClient : public ComputationClient {
28+
private:
29+
// Private struct for making the constructor private, but still callable
30+
// as: `std::make_unique<PjRtComputationClient>(PrivateUse())`.
31+
struct PrivateUse {
32+
// Constructor needs to be explicit for disallowing implicit construction
33+
// from `{}`.
34+
explicit PrivateUse() = default;
35+
};
36+
2837
public:
29-
PjRtComputationClient();
38+
PjRtComputationClient(PrivateUse);
3039
~PjRtComputationClient() override;
3140

3241
DataPtr CreateDataPlaceholder(
@@ -163,6 +172,10 @@ class PjRtComputationClient : public ComputationClient {
163172
void OnReadyCallback(DataPtr data,
164173
const std::function<void()>& callback) override;
165174

175+
// Creates a new instance of PjRtComputationClient and initializes it.
176+
static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
177+
Create();
178+
166179
private:
167180
friend class PjRtComputationClientTest;
168181

@@ -172,6 +185,10 @@ class PjRtComputationClient : public ComputationClient {
172185
fake_xla_compile_ = std::move(function);
173186
}
174187

188+
// Convenience function called by `Create()` that initializes the current
189+
// PjRtComputationClient.
190+
absl::Status Initialize();
191+
175192
std::unique_ptr<xla::PjRtClient> client_;
176193
std::unique_ptr<XlaCoordinator> coordinator_;
177194
// global_ordinals_ tracks a map from PjRtDeviceId to the device's

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
#include <string>
99
#include <vector>
1010

11-
#include "absl/status/status.h"
1211
#include "torch_xla/csrc/runtime/computation_client.h"
1312
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
1413
#include "torch_xla/csrc/runtime/tensor_source.h"
14+
#include "torch_xla/csrc/status.h"
1515
#include "xla/hlo/builder/xla_builder.h"
1616
#include "xla/hlo/builder/xla_computation.h"
1717
#include "xla/literal.h"
@@ -26,7 +26,7 @@ class PjRtComputationClientTest : public ::testing::Test {
2626
PjRtComputationClientTest() {
2727
// Get a CPU client.
2828
tsl::setenv("PJRT_DEVICE", "CPU", true);
29-
client_ = std::make_unique<PjRtComputationClient>();
29+
client_ = GetValueOrThrow(PjRtComputationClient::Create());
3030
device_ = client_->GetDefaultDevice();
3131
}
3232

torch_xla/csrc/runtime/runtime.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "torch_xla/csrc/runtime/env_vars.h"
88
#include "torch_xla/csrc/runtime/ifrt_computation_client.h"
99
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
10+
#include "torch_xla/csrc/status.h"
1011
#include "tsl/platform/stacktrace_handler.h"
1112

1213
namespace torch_xla::runtime {
@@ -32,13 +33,18 @@ InitializeComputationClient() {
3233
// static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
3334
const bool use_ifrt = false;
3435
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") == "") {
35-
return absl::FailedPreconditionError("$PJRT_DEVICE is not set.");
36+
return XLA_ERROR_WITH_LOCATION(
37+
absl::FailedPreconditionError("$PJRT_DEVICE is not set."));
3638
}
3739

40+
std::unique_ptr<ComputationClient> client;
3841
if (use_ifrt) {
39-
return new IfrtComputationClient();
42+
XLA_ASSIGN_OR_RETURN(client, IfrtComputationClient::Create());
43+
} else {
44+
XLA_ASSIGN_OR_RETURN(client, PjRtComputationClient::Create());
4045
}
41-
return new PjRtComputationClient();
46+
47+
return client.release();
4248
}
4349

4450
const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {

0 commit comments

Comments
 (0)