Skip to content

Commit bc68625

Browse files
salilsdesaipytorchmergebot
authored andcommitted
[Vulkan] Add support for Optimization Blocklist to Vulkan Rewrite (pytorch#87431)
Optimization Blocklist will be used in a future diff (D40315730) to make the rewrite to transfer input/output backends optional Differential Revision: [D40315729](https://our.internmc.facebook.com/intern/diff/D40315729/) Pull Request resolved: pytorch#87431 Approved by: https://github.com/mcr229, https://github.com/digantdesai
1 parent f717986 commit bc68625

9 files changed

+35
-19
lines changed

binaries/optimize_for_mobile.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616

1717
#include <string>
1818
#include <sstream>
19-
#include "torch/script.h"
20-
#include "torch/csrc/jit/api/module.h"
19+
#include <torch/script.h>
20+
#include <torch/csrc/jit/api/module.h>
2121
#include <torch/csrc/jit/passes/metal_rewrite.h>
22-
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
23-
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
24-
#include "torch/csrc/jit/serialization/import.h"
25-
#include "torch/csrc/jit/serialization/export.h"
22+
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
23+
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
24+
#include <torch/csrc/jit/serialization/import.h>
25+
#include <torch/csrc/jit/serialization/export.h>
2626

2727
C10_DEFINE_string(model, "", "The torch script model to optimize.");
2828
C10_DEFINE_string(
@@ -86,7 +86,8 @@ int main(int argc, char** argv) {
8686
if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
8787
optimized_module = torch::jit::optimizeForMobile(module);
8888
} else if (FLAGS_backend == "vulkan") {
89-
optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods);
89+
optimized_module = torch::jit::vulkanOptimizeForMobile(
90+
module, std::set<MobileOptimizerType>(), preserved_methods);
9091
} else if (FLAGS_backend == "metal"){
9192
optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods);
9293
}else{

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class Future(object):
169169

170170
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
171171

172-
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
172+
# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
173173
class MobileOptimizerType:
174174
...
175175

@@ -215,6 +215,7 @@ def _clone_module_with_class(module: 'torch.jit.ScriptModule',
215215
ignored_methods: List[AnyStr],
216216
ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
217217
def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
218+
optimization_blocklist: Set[MobileOptimizerType],
218219
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
219220
def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
220221
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
5+
enum class MobileOptimizerType : int8_t {
6+
CONV_BN_FUSION,
7+
INSERT_FOLD_PREPACK_OPS,
8+
REMOVE_DROPOUT,
9+
FUSE_ADD_RELU,
10+
HOIST_CONV_PACKED_PARAMS,
11+
CONV_1D_TO_2D,
12+
};

torch/csrc/jit/passes/vulkan_rewrite.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ void vulkanRunCanonicalOptimizations(script::Module& module) {
269269

270270
script::Module vulkanOptimizeForMobile(
271271
const script::Module& m,
272+
const std::set<MobileOptimizerType>& optimization_blocklist,
272273
const std::vector<std::string>& preserved_methods) {
273274
auto cloned_module = m.clone();
274275
cloned_module.eval();

torch/csrc/jit/passes/vulkan_rewrite.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <torch/csrc/jit/api/module.h>
44
#include <torch/csrc/jit/ir/ir.h>
5+
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
56

67
namespace torch {
78
namespace jit {
@@ -11,6 +12,7 @@ TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
1112
TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
1213
TORCH_API script::Module vulkanOptimizeForMobile(
1314
const script::Module& module,
15+
const std::set<MobileOptimizerType>& optimization_blocklist,
1416
const std::vector<std::string>& preserved_methods);
1517
} // namespace jit
1618
} // namespace torch

torch/csrc/jit/passes/xnnpack_rewrite.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
1313
#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
1414
#include <torch/csrc/jit/passes/inliner.h>
15+
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
1516
#include <torch/csrc/jit/passes/prepack_folding.h>
1617
#include <torch/csrc/jit/passes/remove_dropout.h>
1718
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

torch/csrc/jit/passes/xnnpack_rewrite.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,11 @@
22

33
#include <torch/csrc/jit/api/module.h>
44
#include <torch/csrc/jit/ir/ir.h>
5+
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
56

67
namespace torch {
78
namespace jit {
89

9-
enum class MobileOptimizerType : int8_t {
10-
CONV_BN_FUSION,
11-
INSERT_FOLD_PREPACK_OPS,
12-
REMOVE_DROPOUT,
13-
FUSE_ADD_RELU,
14-
HOIST_CONV_PACKED_PARAMS,
15-
CONV_1D_TO_2D,
16-
};
17-
1810
TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
1911
TORCH_API void transformConv1dToConv2d(script::Module& module);
2012
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);

torch/csrc/jit/python/init.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include <torch/csrc/jit/passes/lower_graph.h>
5353
#include <torch/csrc/jit/passes/lower_tuples.h>
5454
#include <torch/csrc/jit/passes/metal_rewrite.h>
55+
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
5556
#include <torch/csrc/jit/passes/normalize_ops.h>
5657
#include <torch/csrc/jit/passes/peephole.h>
5758
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
@@ -1081,8 +1082,10 @@ void initJITBindings(PyObject* module) {
10811082
.def(
10821083
"_jit_pass_vulkan_optimize_for_mobile",
10831084
[](script::Module& module,
1085+
std::set<MobileOptimizerType>& optimization_blocklist,
10841086
std::vector<std::string>& preserved_methods) {
1085-
return vulkanOptimizeForMobile(module, preserved_methods);
1087+
return vulkanOptimizeForMobile(
1088+
module, optimization_blocklist, preserved_methods);
10861089
})
10871090
.def(
10881091
"_jit_pass_metal_insert_prepacked_ops",

torch/utils/mobile_optimizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def optimize_for_mobile(
6464
optimization_blocklist,
6565
preserved_methods_str)
6666
elif backend == 'vulkan':
67-
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
67+
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(
68+
script_module._c,
69+
optimization_blocklist,
70+
preserved_methods_str)
6871
elif backend == 'metal':
6972
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
7073
else:

0 commit comments

Comments
 (0)