Skip to content

Commit f87f1d0

Browse files
Mike Iovine and facebook-github-bot
Mike Iovine
authored and committed
[SR] assignStorageToManagedTensors returns a vector (pytorch#69568)
Summary: Pull Request resolved: pytorch#69568 Non-empty vectors should never be passed to `assignStorageToManagedTensors` and `assignStorageToManagedOutputTensors`. Presumably, this out-variant convention was adopted to avoid move-assigning the corresponding attributes in `MemoryPlanner`. But the cost of a vector move-assign is not high, and this function type signature is safer. Test Plan: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest` Reviewed By: donaldong Differential Revision: D32729289 fbshipit-source-id: 88f19de8eb89d8a4f1dd8bbd4d9e7f686e41888b
1 parent 9aa1b3e commit f87f1d0

File tree

3 files changed

+16
-19
lines changed

3 files changed

+16
-19
lines changed

benchmarks/static_runtime/test_static_module.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,9 +1230,8 @@ void testAssignStorageToManagedTensors(
12301230
ASSERT_EQ(managed_tensor_values.size(), tensor_value_to_tensor.size());
12311231

12321232
auto ranges = ManagedTensorRanges(graph, managed_tensor_values);
1233-
std::vector<StorageGroup> groups;
1234-
assignStorageToManagedTensors(
1235-
graph->block()->nodes(), ranges, tensor_value_to_tensor, groups);
1233+
auto groups = assignStorageToManagedTensors(
1234+
graph->block()->nodes(), ranges, tensor_value_to_tensor);
12361235

12371236
checkStorageGroups(
12381237
groups, ranges, tensor_value_to_tensor, min_reused_tensors);

torch/csrc/jit/runtime/static/memory_planner.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ FastMap<const Value*, at::Tensor*> tensorValueToTensor(
5656

5757
} // namespace
5858

59-
void assignStorageToManagedTensors(
59+
std::vector<StorageGroup> assignStorageToManagedTensors(
6060
graph_node_list nodes,
6161
const ManagedTensorRanges& ranges,
62-
const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor,
63-
std::vector<StorageGroup>& managed_tensor_groups) {
62+
const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor) {
63+
std::vector<StorageGroup> managed_tensor_groups;
6464
// This set maps each Value* to its assigned storage group.
6565
FastMap<const Value*, size_t> storage_group_mapping;
6666
// On each iteration, this vector stores the set of storage groups that
@@ -119,6 +119,7 @@ void assignStorageToManagedTensors(
119119
}
120120
}
121121
}
122+
return managed_tensor_groups;
122123
}
123124

124125
namespace {
@@ -127,10 +128,10 @@ bool setIncludes(const FastSet<const Value*>& set, const Value* v) {
127128
return set.find(v) != set.end();
128129
}
129130

130-
void assignStorageToOutputTensors(
131+
std::vector<std::pair<size_t, at::Tensor*>> assignStorageToOutputTensors(
131132
StaticRuntime* runtime,
132-
const FastSet<const Value*>& managed_output_tensor_values,
133-
std::vector<std::pair<size_t, at::Tensor*>>& managed_output_tensors) {
133+
const FastSet<const Value*>& managed_output_tensor_values) {
134+
std::vector<std::pair<size_t, at::Tensor*>> managed_output_tensors;
134135
for (auto& pnode : runtime->nodes()) {
135136
for (const auto i : c10::irange(pnode.outputs().size())) {
136137
auto& ival = pnode.Output(i);
@@ -144,6 +145,7 @@ void assignStorageToOutputTensors(
144145
managed_output_tensors.emplace_back(0, tensor);
145146
}
146147
}
148+
return managed_output_tensors;
147149
}
148150

149151
} // namespace
@@ -213,11 +215,8 @@ MemoryPlanner::MemoryPlanner(
213215
const auto tensor_value_to_tensor =
214216
tensorValueToTensor(runtime->nodes(), managed_tensor_values);
215217
if (optimize_memory) {
216-
::torch::jit::assignStorageToManagedTensors(
217-
runtime->node_ptrs(),
218-
ranges,
219-
tensor_value_to_tensor,
220-
managed_tensors_);
218+
managed_tensors_ = assignStorageToManagedTensors(
219+
runtime->node_ptrs(), ranges, tensor_value_to_tensor);
221220
} else {
222221
for (auto& tensor : tensor_value_to_tensor) {
223222
managed_tensors_.emplace_back(tensor.second);
@@ -226,8 +225,8 @@ MemoryPlanner::MemoryPlanner(
226225
}
227226

228227
if (enable_out_variant && manage_output_tensors) {
229-
::torch::jit::assignStorageToOutputTensors(
230-
runtime, managed_output_tensor_values, managed_output_tensors_);
228+
managed_output_tensors_ =
229+
assignStorageToOutputTensors(runtime, managed_output_tensor_values);
231230
}
232231

233232
num_managed_tensors_ = 0;

torch/csrc/jit/runtime/static/memory_planner.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,10 @@ class StorageGroup {
3939
std::vector<at::Tensor*> group_{};
4040
};
4141

42-
TORCH_API void assignStorageToManagedTensors(
42+
TORCH_API std::vector<StorageGroup> assignStorageToManagedTensors(
4343
graph_node_list nodes,
4444
const ManagedTensorRanges& ranges,
45-
const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor,
46-
std::vector<StorageGroup>& managed_tensor_groups);
45+
const FastMap<const Value*, at::Tensor*>& tensor_value_to_tensor);
4746

4847
// There are three types of ops in a processed graph in Static Runtime:
4948
// 1. op with _out variant

0 commit comments

Comments (0)