@@ -115,10 +115,12 @@ std::vector<std::string> PjRtComputationClient::PjRtDevicesToString(
115
115
return strs;
116
116
}
117
117
118
- PjRtComputationClient::PjRtComputationClient () {
118
+ PjRtComputationClient::PjRtComputationClient (PrivateUse) {}
119
+
120
+ absl::Status PjRtComputationClient::Initialize () {
119
121
std::string device_type = sys_util::GetEnvString (env::kEnvPjRtDevice , " " );
120
- std::tie (client_, coordinator_) =
121
- GetValueOrThrow ( InitializePjRt (device_type));
122
+ XLA_ASSIGN_OR_RETURN ( std::tie (client_, coordinator_),
123
+ InitializePjRt (device_type));
122
124
123
125
// PjRtDevice IDs are not guaranteed to be dense, so we need to track
124
126
// a device's global ordinal separately from its device ID. Order the
@@ -137,6 +139,15 @@ PjRtComputationClient::PjRtComputationClient() {
137
139
auto tracked_devices = GetLocalDevices ();
138
140
tracked_devices.emplace_back (spmd_device_str);
139
141
operation_manager_ = std::move (OperationManager (std::move (tracked_devices)));
142
+
143
+ return absl::OkStatus ();
144
+ }
145
+
146
+ absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
147
+ PjRtComputationClient::Create () {
148
+ auto pjrt_client = std::make_unique<PjRtComputationClient>(PrivateUse ());
149
+ XLA_RETURN_IF_ERROR (pjrt_client->Initialize ());
150
+ return std::move (pjrt_client);
140
151
}
141
152
142
153
PjRtComputationClient::~PjRtComputationClient () {
@@ -837,7 +848,7 @@ PjRtComputationClient::ExecuteReplicated(
837
848
argument_handles[d][i] = shard->buffer .get ();
838
849
}
839
850
counter.DecrementCount ();
840
- };
851
+ }
841
852
});
842
853
counter.Wait ();
843
854
}
@@ -962,7 +973,7 @@ int PjRtComputationClient::GetNumProcesses() const {
962
973
}
963
974
964
975
return max_process_index + 1 ;
965
- };
976
+ }
966
977
967
978
std::string PjRtComputationClient::GetDeviceKind (const std::string& device) {
968
979
return std::string (StringToPjRtDevice (device)->device_kind ());
0 commit comments