Commit 90fc5e3

Anjali Sridhar authored and yifeif committed
Allow is_initialized and initializer to be called on MirroredVariables and TowerLocalVariables.
PiperOrigin-RevId: 203520287
1 parent d64754c commit 90fc5e3
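
The user-visible effect of this change: a variable created under a MirroredStrategy (a MirroredVariable) or a tower-local variable (TowerLocalVariable) now exposes is_initialized() and initializer, so it can be checked and initialized like an ordinary variable. Below is a minimal graph-mode sketch adapted from the tests added in this commit; the module paths, device list, and soft-placement config come from the test file, while the standalone Session usage is an illustrative assumption (the test uses self.evaluate), and running it as-is assumes a TF 1.x contrib build with at least one GPU available.

# Minimal sketch adapted from MirroredAndTowerLocalVariableInitializerTest.
# Assumes a TF 1.x build with tf.contrib.distribute and at least one GPU.
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.ops import variable_scope


def var_fn():
  # Called once per tower; the strategy wraps the per-tower copies into a
  # single MirroredVariable.
  return variable_scope.variable(1.0, name="foo")


with context.graph_mode():
  dist = mirrored_strategy.MirroredStrategy(
      ["/device:GPU:0", "/device:CPU:0"])
  with dist.scope():
    mirrored_var = dist.call_for_each_tower(var_fn)

  config = config_pb2.ConfigProto(allow_soft_placement=True)
  with session_lib.Session(config=config) as sess:
    # Newly supported: query and drive initialization on the mirrored variable.
    print(sess.run(mirrored_var.is_initialized()))  # False before init
    sess.run(mirrored_var.initializer)
    print(sess.run(mirrored_var.is_initialized()))  # True afterwards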

File tree

84 files changed (+10517, -239 lines)


tensorflow/compiler/jit/xla_compilation_cache.cc

Lines changed: 1 addition & 17 deletions
@@ -40,23 +40,7 @@ namespace tensorflow {
 XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                          DeviceType device_type)
     : client_(client), device_type_(std::move(device_type)) {}
-XlaCompilationCache::~XlaCompilationCache() {
-  // Ensure any use of our programs have completed by waiting for all stream
-  // executors to complete.
-  for (auto* executor : client_->backend().stream_executors()) {
-    bool ok = executor->SynchronizeAllActivity();
-    if (!ok) {
-      LOG(ERROR) << "Error synchronizing activity while waiting for all "
-                    "programs to complete";
-    }
-  }
-  // TODO(b/110813685): Think about the program ownership model. Programs are
-  // currently owned by the compilation cache which means we must wait for
-  // program completion in the destructor. There are multiple compilation caches
-  // around, which complicates things a little. Perhaps having programs be
-  // shared_ptrs (an invasive change) would make the model easier to reason
-  // about?
-}
+XlaCompilationCache::~XlaCompilationCache() = default;
 
 string XlaCompilationCache::DebugString() {
   return "XLA JIT compilation cache";

tensorflow/compiler/jit/xla_device_context.cc

Lines changed: 44 additions & 59 deletions
@@ -67,53 +67,36 @@ Status XlaTransferManager::TransferLiteralToDevice(
   xla::Shape xla_shape;
   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
                                            host_tensor.shape(), &xla_shape));
-  // Create a reference to hold onto host_tensor until after the literal has
-  // been transferred. Also make sure the literal exists until the function
-  // asynchronously completes, as it will be wrapped in an xla::LiteralSlice.
-  TensorReference ref(host_tensor);
-  auto literal = std::make_shared<xla::BorrowingLiteral>(
+  xla::BorrowingLiteral literal(
       static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
 
   const xla::ShapedBuffer& shaped_buffer =
       XlaTensor::FromTensor(device_tensor)->shaped_buffer();
-  VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
+  VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
          << shaped_buffer.ToString();
-  TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
-      stream_, *literal, shaped_buffer));
-  // Unref the host tensor, and capture the literal shared_ptr too so it goes
-  // out of scope when the lambda completes.
-  stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
-  return Status::OK();
+  return transfer_manager_->TransferLiteralToDevice(stream_, literal,
+                                                    shaped_buffer);
 }
 
-void XlaTransferManager::TransferLiteralFromDevice(
-    Tensor* host_tensor, const Tensor& device_tensor,
-    const StatusCallback& done) const {
+Status XlaTransferManager::TransferLiteralFromDevice(
+    Tensor* host_tensor, const Tensor& device_tensor) const {
   const xla::ShapedBuffer& shaped_buffer =
       XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
 
-  TensorReference ref(device_tensor);
-  transfer_manager_->TransferLiteralFromDevice(
-      stream_, shaped_buffer,
-      [=, &shaped_buffer](
-          xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
-        ref.Unref();
-        done([&]() -> Status {
-          TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or));
-          VLOG(1) << "Transfer from device as literal: " << literal->ToString()
-                  << " " << shaped_buffer.ToString();
-          Tensor tensor;
-          TF_RETURN_IF_ERROR(
-              LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
-          // Reshape the tensor back to its declared shape.
-          Status status;
-          if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
-            status = errors::Internal(
-                "Tensor::CopyFrom failed when copying from XLA device to CPU");
-          }
-          return status;
-        }());
-      });
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<xla::Literal> literal,
+      transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer));
+  VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
+          << shaped_buffer.ToString();
+  Tensor tensor;
+  TF_RETURN_IF_ERROR(
+      LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
+  // Reshape the tensor back to its declared shape.
+  if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
+    return errors::Internal(
+        "Tensor::CopyFrom failed when copying from XLA device to CPU");
+  }
+  return Status::OK();
 }
 
 void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -136,7 +119,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
   XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
   CHECK(xla_tensor);
 
-  Status status;
   xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
       device_tensor->shape(), device_tensor->dtype());
   if (!shape_or_status.ok()) {
@@ -145,14 +127,16 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
   }
   TensorShape shape = shape_or_status.ValueOrDie();
   if (!xla_tensor->has_shaped_buffer()) {
-    status = xla_tensor->AllocateShapedBuffer(
+    Status s = xla_tensor->AllocateShapedBuffer(
         device_tensor->dtype(), shape, client_,
         stream_->parent()->device_ordinal());
-    if (!status.ok()) {
-      return done(status);
+    if (!s.ok()) {
+      done(s);
+      return;
     }
   }
 
+  Status status;
   if (transfer_as_literal_) {
     Tensor reshaped_cpu_tensor;
     if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
@@ -205,8 +189,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
 
   Status status;
   if (transfer_as_literal_) {
-    TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
-    return;
+    status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
   } else {
     stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
     // TODO(hpucha): Make this asynchronous.
@@ -216,8 +199,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
           "Failed to complete data transfer on stream %p: %s", stream_,
           block_status.error_message().c_str());
     }
-    done(status);
   }
+
+  done(status);
   return;
 }
 
@@ -228,8 +212,8 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
 void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
                                                   Tensor* dst_tensor,
                                                   const StatusCallback& done) {
-  // Perform memory allocation now, and enqueue the device-to-device transfer.
-  Status status = [&]() -> Status {
+  // TODO(phawkins): replace this code with an asynchronous implementation.
+  auto body = [&]() {
     if (src_tensor.NumElements() == 0) {
       return Status::OK();
     }
@@ -245,20 +229,21 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
           xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));
     }
-    auto from_iter = xla_src->shaped_buffer().buffers().begin();
-    auto to_iter = xla_dst->shaped_buffer().buffers().begin();
-    for (auto end_iter = xla_src->shaped_buffer().buffers().end();
-         from_iter != end_iter; ++from_iter, ++to_iter) {
-      stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second,
-                             to_iter->second.size());
-    }
+    TF_RETURN_IF_ERROR(
+        xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
+            [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+              const se::DeviceMemoryBase& from_buffer =
+                  xla_src->shaped_buffer().buffers().element(index);
+              CHECK_EQ(buffer->size(), from_buffer.size());
+              if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer,
+                                                        buffer->size())) {
+                return errors::Internal("Device to device memcpy failed");
+              }
+              return Status::OK();
+            }));
     return Status::OK();
-  }();
-  if (!status.ok()) {
-    return done(status);
-  } else {
-    stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
-  }
+  };
+  done(body());
 }
 
 XlaDeviceContext::XlaDeviceContext(

tensorflow/compiler/jit/xla_device_context.h

Lines changed: 2 additions & 3 deletions
@@ -64,9 +64,8 @@ class XlaTransferManager {
  private:
   Status TransferLiteralToDevice(const Tensor& host_tensor,
                                  Tensor* device_tensor) const;
-  void TransferLiteralFromDevice(Tensor* host_tensor,
-                                 const Tensor& device_tensor,
-                                 const StatusCallback& done) const;
+  Status TransferLiteralFromDevice(Tensor* host_tensor,
+                                   const Tensor& device_tensor) const;
 
   // Stream obtained from a Device, used to transfer tensors between
   // CPU and device.

tensorflow/compiler/xla/service/executable.cc

Lines changed: 1 addition & 12 deletions
@@ -82,18 +82,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
 
   StatusOr<ScopedShapedBuffer> return_value =
       ExecuteOnStream(run_options, arguments, profile_ptr.get());
-  if (!return_value.status().ok()) {
-    if (profile != nullptr) {
-      // Ensure the ThenStartTimer call has completed before we destroy timer.
-      // We already have a failure status to return, so just log this if it
-      // fails.
-      Status status = stream->BlockHostUntilDone();
-      if (!status.ok()) {
-        LOG(ERROR) << "Failed to BlockHostUntilDone: " << status;
-      }
-    }
-    return return_value.status();
-  }
+  TF_RETURN_IF_ERROR(return_value.status());
 
   if (profile != nullptr) {
     VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";

tensorflow/compiler/xla/service/hlo_runner.cc

Lines changed: 2 additions & 7 deletions
@@ -180,12 +180,8 @@ StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       CreateExecutable(std::move(module), run_hlo_passes));
-  TF_ASSIGN_OR_RETURN(
-      ScopedShapedBuffer retval,
-      executable->ExecuteOnStreamWrapper(&service_run_options,
-                                         /*profile=*/profile, arguments));
-  TF_RETURN_IF_ERROR(stream.BlockHostUntilDone());
-  return std::move(retval);
+  return executable->ExecuteOnStreamWrapper(&service_run_options,
+                                            /*profile=*/profile, arguments);
 }
 
 StatusOr<ScopedShapedBuffer> HloRunner::ExecuteWithDeviceBuffers(
@@ -313,7 +309,6 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
 
   std::vector<std::unique_ptr<Literal>> exec_results;
   for (int64 i = 0; i < options.num_replicas; ++i) {
-    TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
     TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
                         backend().transfer_manager()->TransferLiteralFromDevice(
                             streams[i].get(), results[i]));

tensorflow/compiler/xla/tests/local_client_execute_test.cc

Lines changed: 0 additions & 4 deletions
@@ -772,10 +772,6 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
   ScopedShapedBuffer result =
       executable->Run({&x_array}, DefaultExecutableRunOptions())
          .ConsumeValueOrDie();
-  ASSERT_IS_OK(local_client_->mutable_backend()
-                   ->BorrowStream(0)
-                   .ValueOrDie()
-                   ->BlockHostUntilDone());
 
   LiteralTestUtil::ExpectR1Near<float>(
       {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);

tensorflow/compiler/xla/tests/local_client_test_base.cc

Lines changed: 1 addition & 13 deletions
@@ -189,19 +189,7 @@ StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<LocalExecutable> executable,
       local_client_->Compile(computation, argument_layouts, build_options));
-  TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
-
-  auto device_ordinal =
-      build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
-  auto* stream = run_options.stream();
-  if (!stream) {
-    stream = local_client_->mutable_backend()
-                 ->BorrowStream(device_ordinal)
-                 .ValueOrDie()
-                 .get();
-  }
-  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
-  return std::move(ret);
+  return executable->Run(arguments, run_options);
 }
 
 }  // namespace xla

tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc

Lines changed: 0 additions & 1 deletion
@@ -168,7 +168,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
       auto execution_result,
       executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg},
                                   &hlo_execution_profile));
-  TF_ASSERT_OK(stream_ptr->BlockHostUntilDone());
   (void)execution_result;
 
   *profile_output =

tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py

Lines changed: 44 additions & 0 deletions
@@ -922,5 +922,49 @@ def model_fn():
     self.assertEquals(4.5, self.evaluate(mirrored_var))
 
 
+class MirroredAndTowerLocalVariableInitializerTest(test.TestCase):
+  config = config_pb2.ConfigProto()
+  config.allow_soft_placement = True
+
+  def testAssignMirroredVarInitializer(self):
+    # This test is not eager compatible since in eager variables are initialized
+    # upon construction instead of once the initialization op is run.
+    with context.graph_mode():
+      def var_fn():
+        v = variable_scope.variable(1.0, name="foo")
+        return v
+
+      dist = mirrored_strategy.MirroredStrategy(
+          ["/device:GPU:0", "/device:CPU:0"])
+
+      with dist.scope():
+        mirrored_var = dist.call_for_each_tower(var_fn)
+        self.assertIsInstance(mirrored_var, values.MirroredVariable)
+        self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
+        self.evaluate(mirrored_var.initializer)
+        self.assertTrue(self.evaluate(mirrored_var.is_initialized()))
+
+  def testAssignTowerLocalVarInitializer(self):
+    # This test is not eager compatible since in eager variables are initialized
+    # upon construction instead of once the initialization op is run.
+    with context.graph_mode():
+      def model_fn():
+        tower_context = distribute_lib.get_tower_context()
+        with tower_context.tower_local_var_scope(
+            variable_scope.VariableAggregation.SUM):
+          v_sum = variable_scope.variable(1.0)
+        self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
+        return v_sum
+
+      dist = mirrored_strategy.MirroredStrategy(
+          ["/device:GPU:0", "/device:CPU:0"])
+
+      with dist.scope():
+        tower_local_var = dist.call_for_each_tower(model_fn)
+        self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable))
+        self.assertFalse(self.evaluate(tower_local_var.is_initialized()))
+        self.evaluate(tower_local_var.initializer)
+        self.assertTrue(self.evaluate(tower_local_var.is_initialized()))
+
 if __name__ == "__main__":
   test.main()
