
Commit df1cc0e

salilsdesai authored and pytorchmergebot committed
[Vulkan] Add Vulkan Rewrite to Transfer Inputs and Outputs to Vulkan and CPU Backends Respectively (pytorch#87432)
With this change, we no longer have to manually transfer input and output backends when running Vulkan models.

Graph rewrite code based off of:
- pytorch@32efff4#diff-a473bddb458dc24225866a45092d6eca064eddd256245d93020e48e216eee4d5R160-R179

Differential Revision: [D39519168](https://our.internmc.facebook.com/intern/diff/D39519168/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39519168/)!

Pull Request resolved: pytorch#87432
Approved by: https://github.com/mcr229, https://github.com/digantdesai
1 parent bc68625 commit df1cc0e
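
For context, a minimal sketch of what the rewrite changes for callers. This is not code from the commit; the model path and input shape are placeholders, and it assumes a model already run through `optimize_for_mobile` with `backend="vulkan"`:

```python
import torch

model = torch.jit.load("model_vulkan.pt")  # placeholder path to an optimized model
x = torch.rand(1, 3, 224, 224)

# Previously: callers moved inputs to the Vulkan backend and outputs back to CPU by hand.
y = model(x.vulkan()).cpu()

# With this rewrite: the graph itself inserts the transfers, so plain CPU tensors work.
y = model(x)
```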

File tree

9 files changed: +83, -10 lines changed


android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp

Lines changed: 8 additions & 4 deletions
@@ -195,14 +195,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
       std::vector<at::IValue> inputs{};
       size_t n = jinputs->size();
       inputs.reserve(n);
+      const bool requires_backend_transfers =
+          module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
       for (size_t i = 0; i < n; i++) {
         at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-        if (at::kVulkan == deviceType_) {
+        if (at::kVulkan == deviceType_ && requires_backend_transfers) {
           inputs.push_back(
               atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                   : std::move(atIValue));
         } else {
-          TORCH_CHECK(at::kCPU == deviceType_);
+          TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
           inputs.push_back(std::move(atIValue));
         }
       }
@@ -223,14 +225,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
       std::vector<at::IValue> inputs{};
       size_t n = jinputs->size();
       inputs.reserve(n);
+      const bool requires_backend_transfers =
+          module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
       for (size_t i = 0; i < n; i++) {
         at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-        if (at::kVulkan == deviceType_) {
+        if (at::kVulkan == deviceType_ && requires_backend_transfers) {
           inputs.push_back(
               atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                   : std::move(atIValue));
         } else {
-          TORCH_CHECK(at::kCPU == deviceType_);
+          TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
           inputs.push_back(std::move(atIValue));
         }
       }

android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp

Lines changed: 8 additions & 4 deletions
@@ -158,14 +158,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
       std::vector<at::IValue> inputs{};
       size_t n = jinputs->size();
       inputs.reserve(n);
+      const bool requires_backend_transfers =
+          module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
       for (const auto i : c10::irange(n)) {
         at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-        if (at::kVulkan == deviceType_) {
+        if (at::kVulkan == deviceType_ && requires_backend_transfers) {
           inputs.push_back(
               atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                   : std::move(atIValue));
         } else {
-          TORCH_CHECK(at::kCPU == deviceType_);
+          TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
           inputs.push_back(std::move(atIValue));
         }
       }
@@ -187,14 +189,16 @@ class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
       std::vector<at::IValue> inputs{};
       size_t n = jinputs->size();
       inputs.reserve(n);
+      const bool requires_backend_transfers =
+          module_.attr("requires_backend_transfers", at::IValue(true)).toBool();
       for (const auto i : c10::irange(n)) {
         at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
-        if (at::kVulkan == deviceType_) {
+        if (at::kVulkan == deviceType_ && requires_backend_transfers) {
           inputs.push_back(
               atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()}
                                   : std::move(atIValue));
         } else {
-          TORCH_CHECK(at::kCPU == deviceType_);
+          TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers);
           inputs.push_back(std::move(atIValue));
         }
       }
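
Both JNI wrappers above apply the same guard. A hedged Python rendering of that dispatch logic (`run_on_vulkan` is a hypothetical helper, not part of this PR; the `True` default mirrors `module_.attr("requires_backend_transfers", at::IValue(true))`):

```python
import torch

def run_on_vulkan(module, *inputs):
    # Modules rewritten by the new pass carry requires_backend_transfers == False.
    if not getattr(module, "requires_backend_transfers", True):
        return module(*inputs)  # graph already contains the aten::to transfers
    # Legacy path: transfer tensor inputs to Vulkan, and the result back to CPU.
    moved = [t.vulkan() if isinstance(t, torch.Tensor) else t for t in inputs]
    out = module(*moved)
    return out.cpu() if isinstance(out, torch.Tensor) else out
```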

binaries/speed_benchmark_torch.cc

Lines changed: 4 additions & 0 deletions
@@ -180,6 +180,10 @@ class vkRunner final : public Runner<T> {
   virtual c10::IValue run(
       T& module,
       const std::vector<c10::IValue>& inputs) override {
+    if (!module.attr("requires_backend_transfers", at::IValue(true)).toBool()) {
+      // No need to transfer input/output backends
+      return module.forward(inputs);
+    }

     if (inputs_.size() == 0) {
       // Upload the input tensor(s) to GPU memory.

docs/source/mobile_optimizer.rst

Lines changed: 4 additions & 1 deletion
@@ -7,13 +7,16 @@ torch.utils.mobile_optimizer
 Torch mobile supports ``torch.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode.
 The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set and a preserved method list

-By default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimizations:
+For CPU Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimizations:
 - **Conv2D + BatchNorm fusion** (blocklisting option `MobileOptimizerType::CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in ``forward`` method of this module and all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated.
 - **Insert and Fold prepacked ops** (blocklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops in that, they require some state to be created, such as weight prepacking and use this state, i.e. prepacked weights, during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking of weight enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` pass rewrites the graph to replace ``Conv2D/Linear`` with 1) op that pre-packs weight for XNNPACK conv2d/linear ops and 2) op that takes pre-packed weight and activation as input and generates output activations. Since 1 needs to be done only once, we fold the weight pre-packing such that it is done only once at model load time. This pass of the ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, weight pre-packing ops.
 - **ReLU/Hardtanh fusion**: XNNPACK ops support fusion of clamping. That is clamping of output activation is done as part of the kernel, including for 2D convolution and linear op kernels. Thus clamping effectively comes for free. Thus any op that can be expressed as clamping op, such as ``ReLU`` or ``hardtanh``, can be fused with previous ``Conv2D`` or ``linear`` op in XNNPACK. This pass rewrites graph by finding ``ReLU/hardtanh`` ops that follow XNNPACK ``Conv2D/linear`` ops, written by the previous pass, and fuses them together.
 - **Dropout removal** (blocklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false.
 - **Conv packed params hoisting** (blocklisting option `MobileOptimizerType::HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics.

+For Vulkan Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimization:
+- **Automatic GPU Transfer** (blocklisting option `MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER`): This optimization pass rewrites the graph such that inputs are transferred to the Vulkan backend and outputs are transferred to the CPU backend.
+
 ``optimize_for_mobile`` will also invoke freeze_module pass which only preserves ``forward`` method. If you have other method to that needed to be preserved, add them into the preserved method list and pass into the method.
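
A short usage sketch of the documented Vulkan behavior (`MyModel` is a placeholder module; the `MobileOptimizerType` import matches the binding added in this PR):

```python
import torch
from torch._C import MobileOptimizerType
from torch.utils.mobile_optimizer import optimize_for_mobile

scripted = torch.jit.script(MyModel().eval())  # MyModel is a placeholder

# Default: the Automatic GPU Transfer pass runs, so callers pass CPU tensors directly.
vk_module = optimize_for_mobile(scripted, backend="vulkan")

# Opt out: blocklist the pass and keep transferring backends manually.
vk_manual = optimize_for_mobile(
    scripted,
    optimization_blocklist={MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER},
    backend="vulkan",
)
```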

test/test_public_bindings.py

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ def test_no_new_bindings(self):
             "set_num_threads",
             "unify_type_list",
             "vitals_enabled",
-
+            "VULKAN_AUTOMATIC_GPU_TRANSFER",
             "wait",
             "Tag",
         }

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -178,6 +178,7 @@ INSERT_FOLD_PREPACK_OPS: MobileOptimizerType
 REMOVE_DROPOUT: MobileOptimizerType
 FUSE_ADD_RELU: MobileOptimizerType
 HOIST_CONV_PACKED_PARAMS: MobileOptimizerType
+VULKAN_AUTOMATIC_GPU_TRANSFER: MobileOptimizerType

 def fork(*args: Any, **kwargs: Any) -> Future: ...
 def wait(fut: Future) -> Any: ...

torch/csrc/jit/passes/mobile_optimizer_type.h

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@ enum class MobileOptimizerType : int8_t {
   FUSE_ADD_RELU,
   HOIST_CONV_PACKED_PARAMS,
   CONV_1D_TO_2D,
+  VULKAN_AUTOMATIC_GPU_TRANSFER,
 };

torch/csrc/jit/passes/vulkan_rewrite.cpp

Lines changed: 53 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/fold_conv_bn.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
@@ -82,6 +83,51 @@ void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
   transpose_rewriter.runOnGraph(graph);
 }

+void transferInputOutputBackends(std::shared_ptr<Graph>& graph) {
+  // Move inputs to Vulkan backend
+  for (Value* input : graph->inputs()) {
+    NamedValue named_input = NamedValue("", input);
+    if (named_input.type()->kind() == TypeKind::TensorType) {
+      // find the insertion point
+      WithInsertPoint ip(input->uses()[0].user->prev());
+      Value* replaced_input = graph->insert(
+          Symbol::fromQualString("aten::to"), {named_input, "vulkan"});
+      // replace the input
+      input->replaceAllUsesAfterNodeWith(
+          replaced_input->node(), replaced_input);
+    }
+  }
+
+  // Move outputs to CPU backend
+  at::ArrayRef<Value*>&& outputs = graph->outputs();
+  for (size_t i = 0; i < outputs.size(); i++) {
+    Value* output = outputs[i];
+    NamedValue named_output = NamedValue("", output);
+    if (named_output.type()->kind() == TypeKind::TensorType) {
+      // find the insertion point
+      WithInsertPoint ip(output->node()->next());
+      Value* replaced_output = graph->insert(
+          Symbol::fromQualString("aten::to"), {named_output, "cpu"});
+      // replace the output
+      graph->block()->replaceOutput(i, replaced_output);
+    }
+  }
+
+  SubgraphRewriter rewriter;
+  rewriter.runOnGraph(graph);
+}
+
+void transferInputOutputBackends(script::Module& module) {
+  std::shared_ptr<Graph> graph = module.get_methods()[0].graph();
+  transferInputOutputBackends(graph);
+}
+
+void eliminateDeadCode(script::Module& module) {
+  for (auto& method : module.get_methods()) {
+    EliminateDeadCode(method.graph());
+  }
+}
+
 void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) {
   std::string gru_pattern = R"(
       graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
@@ -276,12 +322,19 @@ script::Module vulkanOptimizeForMobile(
   cloned_module = FoldConvBatchNorm(cloned_module);
   vulkanInsertPrePackedOps(cloned_module);
   cloned_module = freeze_module(cloned_module, preserved_methods);
+  if (!optimization_blocklist.count(
+          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
+    transferInputOutputBackends(cloned_module);
+    cloned_module.register_attribute(
+        "requires_backend_transfers", BoolType::get(), false);
+  }
   vulkanFusePrePackedConvWithClamp(cloned_module);
   vulkanFoldPrePackingOps(cloned_module);
   removeDropout(cloned_module);
   vulkanRemoveMutation(cloned_module);
   // remove duplicated constants
   vulkanRunCanonicalOptimizations(cloned_module);
+  eliminateDeadCode(cloned_module);

   cloned_module.register_attribute(
       "optimized_for_vulkan", BoolType::get(), true);

torch/csrc/jit/python/init.cpp

Lines changed: 3 additions & 0 deletions
@@ -1297,6 +1297,9 @@ void initJITBindings(PyObject* module) {
       .value(
           "HOIST_CONV_PACKED_PARAMS",
           MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
+      .value(
+          "VULKAN_AUTOMATIC_GPU_TRANSFER",
+          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)
       .export_values();

   // This allows PyTorchStreamReader to read from a Python buffer. It requires
