Skip to content

Commit 26a8bb6

Browse files
apaszke authored and
facebook-github-bot committed
Re-enabled mm+add tree batching in the JIT (pytorch#13228)
Summary: I've had to generously increase the range of the CreateADSubgraphs pass, because even though it collapses the RNN loop to a single differentiable subgraph and a few other nodes, the range uses the distances in the original graph... cc zdevito zou3519 Pull Request resolved: pytorch#13228 Differential Revision: D12871316 Pulled By: zou3519 fbshipit-source-id: 32da6f30f7821e4339034f1a4dec41ed0849abfb
1 parent 81438f1 commit 26a8bb6

File tree

6 files changed

+127
-64
lines changed

6 files changed

+127
-64
lines changed

torch/csrc/jit/init.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ void initJITBindings(PyObject *module) {
101101
return EliminateCommonSubexpression(g); // overload resolution
102102
})
103103
.def("_jit_pass_constant_pooling", ConstantPooling)
104-
.def("_jit_pass_peephole", PeepholeOptimize, py::arg("graph"), py::arg("addmm_fusion_enabled") = false)
104+
.def("_jit_pass_peephole", [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
105+
return PeepholeOptimize(g, addmm_fusion_enabled);
106+
}, py::arg("graph"), py::arg("addmm_fusion_enabled") = false)
105107
.def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
106108
return Canonicalize(g);
107109
})

torch/csrc/jit/passes/batch_mm.cpp

+111-55
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#include "torch/csrc/jit/passes/batch_mm.h"
22

33
#include "torch/csrc/jit/passes/dead_code_elimination.h"
4+
#include "torch/csrc/jit/passes/peephole.h"
45
#include "torch/csrc/jit/interned_strings.h"
56
#include "torch/csrc/jit/constants.h"
67
#include "torch/csrc/jit/symbolic_variable.h"
78
#include "torch/csrc/jit/assertions.h"
9+
#include "torch/csrc/jit/custom_operator.h"
810
#include "torch/csrc/utils/functional.h"
911

1012
#include <ATen/ATen.h>
@@ -68,52 +70,98 @@ namespace torch { namespace jit {
6870
// the trees we formed and fuse them.
6971

7072
// Tunable parameter. Set to something larger if it turns out to be better.
71-
static constexpr size_t min_fusion_size = 2;
73+
static constexpr size_t min_fusion_size = 4;
7274

73-
static std::array<int64_t, 2> as_array(at::IntList sizes) {
74-
JIT_ASSERT(sizes.size() == 2);
75-
std::array<int64_t, 2> arr = {sizes[0], sizes[1]};
76-
return arr;
75+
bool have_same_shape(at::TensorList inputs) {
76+
auto expected_sizes = inputs[0].sizes();
77+
return std::all_of(inputs.begin(), inputs.end(),
78+
[expected_sizes](const at::Tensor& t) {
79+
return t.sizes() == expected_sizes;
80+
});
7781
}
7882

83+
bool shape_is_fast(const at::Tensor& lhs, const at::Tensor& rhs) {
84+
size_t l = lhs.size(0);
85+
size_t m = lhs.size(1);
86+
size_t r = rhs.size(1);
87+
// Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
88+
return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
89+
}
90+
91+
RegisterOperators mm_tree_reduction_reg({
92+
Operator(
93+
Symbol::prim("MMTreeReduce"),
94+
[](const Node* node) {
95+
size_t num_inputs = node->inputs().size();
96+
return [num_inputs](Stack& stack) {
97+
std::vector<at::Tensor> inputs;
98+
inputs.reserve(num_inputs);
99+
for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
100+
inputs.push_back(std::move(*it).toTensor());
101+
}
102+
drop(stack, num_inputs);
103+
104+
JIT_ASSERT(inputs.size() > 0);
105+
JIT_ASSERT(inputs.size() % 2 == 0);
106+
size_t side_num_elems = inputs.size() / 2;
107+
auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
108+
auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
109+
// TODO: checking this is not free, so we should stop if this keeps failing
110+
// TODO: benchmark to find when is this really a win, and add size constraints
111+
if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) && shape_is_fast(lhs_inputs[0], rhs_inputs[0])) {
112+
auto lhs = at::cat(lhs_inputs, /*dim=*/1);
113+
auto rhs = at::cat(rhs_inputs, /*dim=*/0);
114+
push(stack, at::mm(lhs, rhs));
115+
} else {
116+
auto acc = at::mm(inputs[0], inputs[side_num_elems]);
117+
for (size_t i = 1; i < side_num_elems; ++i) {
118+
acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
119+
}
120+
push(stack, std::move(acc));
121+
}
122+
return 0;
123+
};
124+
})
125+
});
126+
79127
// TreeTokens will be used to label nodes of the graph, if the nodes will fit
80128
// our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
81129
// when we reach node N with inputs A and B, then A and B have already been
82130
// procesed, and we can try to unify their TreeTokens (if they have them)
83131
// and build a larger tree.
84132
struct TreeToken {
85133
uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
86-
std::array<int64_t, 2> lhs_sizes{{0, 0}};
87-
std::array<int64_t, 2> rhs_sizes{{0, 0}};
88134
Node *node = nullptr;
89135
bool is_root = false;
90136

91-
static TreeToken fromMM(Node *mm) {
137+
static TreeToken mm(Node *mm) {
92138
TreeToken token;
93139
token.tree_size = 1;
94-
Value *lhs = mm->inputs()[0];
95-
Value *rhs = mm->inputs()[1];
96-
token.lhs_sizes = as_array(lhs->type()->expect<CompleteTensorType>()->sizes());
97-
token.rhs_sizes = as_array(rhs->type()->expect<CompleteTensorType>()->sizes());
98140
token.node = mm;
99141
token.is_root = true;
100142
return token;
101143
}
102144

103-
static TreeToken unify(Node *add, TreeToken& l, TreeToken& r) {
145+
// NB: the returned token might be invalid, so make sure to check its boolean value!
146+
static TreeToken transpose(Node *t, TreeToken& inp_token) {
147+
TreeToken token;
148+
if (!inp_token.node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
149+
return token;
150+
}
151+
token.tree_size = 1;
152+
token.node = t;
153+
token.is_root = true;
154+
inp_token.is_root = false;
155+
return token;
156+
}
157+
158+
// NB: the returned token might be invalid, so make sure to check its boolean value!
159+
static TreeToken add(Node *add, TreeToken& l, TreeToken& r) {
104160
TreeToken token;
105161
// See Note [Overlapping trees]
106162
if (&l == &r || !l.is_root || !r.is_root)
107163
return token;
108-
// We can batch the tree only if all sizes match, because we need to
109-
// cat inputs for both operands
110-
if (l.lhs_sizes != r.lhs_sizes)
111-
return token;
112-
if (l.rhs_sizes != r.rhs_sizes)
113-
return token;
114164
token.tree_size = l.tree_size + r.tree_size;
115-
token.lhs_sizes = l.lhs_sizes;
116-
token.rhs_sizes = l.rhs_sizes;
117165
token.node = add;
118166
token.is_root = true;
119167
l.is_root = r.is_root = false; // Reserve the subtrees, so they can't be used again.
@@ -124,16 +172,31 @@ struct TreeToken {
124172
return is_root;
125173
}
126174

127-
std::vector<Node*> gatherMatMuls() {
175+
std::vector<Node*> removeTransposesAndGatherMatmuls() {
128176
std::vector<Node*> matmuls;
129177
std::vector<Node*> queue {node};
178+
Graph* graph = node->owningGraph();
130179
while (!queue.empty()) {
131180
auto n = queue.back(); queue.pop_back();
132-
if (n->kind() == aten::mm) {
181+
if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
133182
matmuls.push_back(n);
134-
} else {
183+
} else if (n->matches("aten::t(Tensor self) -> Tensor")) {
184+
Node * input_node = n->input()->node();
185+
JIT_ASSERT(input_node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor"));
186+
// (AB)^T == B^TA^T
187+
WithInsertPoint insert_guard { input_node };
188+
Value * A = input_node->inputs()[0];
189+
Value * B = input_node->inputs()[1];
190+
Value * AT = graph->insert(aten::t, {A});
191+
Value * BT = graph->insert(aten::t, {B});
192+
Value * BTAT = graph->insert(aten::mm, {BT, AT});
193+
n->output()->replaceAllUsesWith(BTAT);
194+
matmuls.push_back(BTAT->node());
195+
} else if (n->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
135196
queue.push_back(n->inputs()[0]->node());
136197
queue.push_back(n->inputs()[1]->node());
198+
} else {
199+
AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
137200
}
138201
}
139202
return matmuls;
@@ -147,13 +210,14 @@ void BatchMMBlock(Block* block) {
147210
// Look for trees in the block
148211
std::unordered_map<Node*, TreeToken> tokens;
149212
for (auto node : block->nodes()) {
150-
if (node->kind() == aten::mm &&
151-
node->input(0)->type()->cast<CompleteTensorType>() &&
152-
node->input(1)->type()->cast<CompleteTensorType>()) {
153-
tokens[node] = TreeToken::fromMM(node);
154-
} else if (node->kind() == aten::add) {
155-
// NOTE: x + 2 is add[other={2}](%x)
156-
if (node->inputs().size() != 2) continue;
213+
if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
214+
tokens[node] = TreeToken::mm(node);
215+
} else if (node->matches("aten::t(Tensor self) -> Tensor")) {
216+
auto input_it = tokens.find(node->input()->node());
217+
if (input_it != tokens.end()) {
218+
tokens[node] = TreeToken::transpose(node, input_it->second);
219+
}
220+
} else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
157221
Node *lhs = node->inputs()[0]->node();
158222
Node *rhs = node->inputs()[1]->node();
159223
auto lhs_it = tokens.find(lhs);
@@ -166,8 +230,9 @@ void BatchMMBlock(Block* block) {
166230
// we need to compute a transitive closure and actually check the dependencies.
167231
if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
168232
lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {
169-
if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
233+
if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
170234
tokens[node] = token;
235+
}
171236
}
172237
} else {
173238
for (auto block : node->blocks()) {
@@ -181,35 +246,26 @@ void BatchMMBlock(Block* block) {
181246
auto & root = item.second;
182247
if (!root || root.tree_size < min_fusion_size)
183248
continue;
184-
auto matmuls = root.gatherMatMuls();
185-
auto type_ = root.node->output()->type();
186-
auto type = type_->expect<CompleteTensorType>();
187-
188-
auto batch_inputs = [&](Side s, std::array<int64_t, 2> cat_sizes) -> Value* {
189-
int inputs_off = s == Side::LHS ? 0 : 1;
190-
int cat_dim = s == Side::LHS ? 1 : 0;
191-
cat_sizes[cat_dim] *= matmuls.size(); // make them really cat_sizes
192-
193-
WithInsertPoint iguard { root.node };
194-
auto inputs = fmap(matmuls, [=](Node *mm) -> SymbolicVariable { return mm->inputs()[inputs_off]; });
195-
auto cat_output = SymbolicVariable::cat(inputs, cat_dim).value();
196-
cat_output->setType(type->withSizes(cat_sizes));
197-
return cat_output;
198-
};
199-
200-
auto lhs_batch = batch_inputs(Side::LHS, root.lhs_sizes);
201-
auto rhs_batch = batch_inputs(Side::RHS, root.rhs_sizes);
202-
Node *batch_mm = graph->create(aten::mm, {lhs_batch, rhs_batch});
203-
batch_mm->output()->setType(type_);
204-
batch_mm->insertBefore(root.node);
205-
root.node->output()->replaceAllUsesWith(batch_mm->output());
249+
auto matmuls = root.removeTransposesAndGatherMatmuls();
250+
WithInsertPoint insert_guard {root.node};
251+
Node * tree_reduce = graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
252+
for (Node * matmul : matmuls) {
253+
tree_reduce->addInput(matmul->inputs().at(0));
254+
}
255+
for (Node * matmul : matmuls) {
256+
tree_reduce->addInput(matmul->inputs().at(1));
257+
}
258+
root.node->output()->replaceAllUsesWith(tree_reduce->output());
206259
// NB: don't bother with cleaning up after yourself. We'll use DCE for that.
207260
}
208-
EliminateDeadCode(block);
209261
}
210262

211263
void BatchMM(std::shared_ptr<Graph>& graph) {
212264
BatchMMBlock(graph->block());
265+
EliminateDeadCode(graph);
266+
// It's possible that transpose rearrangements have created sequences of consecutive
267+
// transposes that didn't exist before.
268+
PeepholeOptimize(graph);
213269
}
214270

215271
}}

torch/csrc/jit/passes/create_autodiff_subgraphs.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ static detail::DynamicDAG<Node*> make_dependency_graph(Block * block) {
139139

140140
static void find_differentiable_groups(
141141
detail::DynamicDAG<Node*>& dep_graph,
142-
size_t distance_threshold=64,
142+
size_t distance_threshold=256,
143143
size_t producer_edge_threshold=16) {
144144
// A Vertex contains a Node* or a differentiable group of Node*.
145145
// Perform graph contraction on dep_graph: contract two vertices(x, y) if

torch/csrc/jit/passes/loop_unrolling.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace torch { namespace jit {
1212
namespace {
1313

1414
static constexpr int64_t kUnrollFactor = 8;
15-
static constexpr int64_t kMaxBodySize = 16;
15+
static constexpr int64_t kMaxBodySize = 32;
1616
static constexpr int64_t kMaxBodyRepeats = 64;
1717

1818
bool isTrueConstant(Value *val) {

torch/csrc/jit/passes/peephole.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ namespace torch { namespace jit {
2121
// we would see redundant Gemm ops with sub-optimal inputs. This flag is exposed
2222
// so that ONNX export can pass `true` to get the fused behavior, but normal
2323
// JIT peephole optimization is left alone.
24-
void PeepholeOptimize(Block * block, bool addmm_fusion_enabled) {
24+
void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
2525
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
2626
auto* node = *it;
2727

2828
for (Block * sub_block : node->blocks()) {
29-
PeepholeOptimize(sub_block, addmm_fusion_enabled);
29+
PeepholeOptimizeImpl(sub_block, addmm_fusion_enabled);
3030
}
3131

3232
// XXX: remember that if you want to simplify an expression by combining multiple nodes
@@ -129,10 +129,14 @@ void PeepholeOptimize(Block * block, bool addmm_fusion_enabled) {
129129
}
130130
}
131131

132-
void PeepholeOptimize(std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled) {
133-
PeepholeOptimize(graph->block(), addmm_fusion_enabled);
132+
void PeepholeOptimize(Block* block, bool addmm_fusion_enabled) {
133+
PeepholeOptimizeImpl(block, addmm_fusion_enabled);
134134
// Eliminate dead code created by any peephole passes we've just done
135-
EliminateDeadCode(graph->block());
135+
EliminateDeadCode(block);
136+
}
137+
138+
void PeepholeOptimize(const std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled) {
139+
PeepholeOptimize(graph->block(), addmm_fusion_enabled);
136140
}
137141

138142
}}

torch/csrc/jit/passes/peephole.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
namespace torch { namespace jit {
66

7-
TORCH_API void PeepholeOptimize(std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled=false);
7+
TORCH_API void PeepholeOptimize(const std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled=false);
8+
TORCH_API void PeepholeOptimize(Block* block, bool addmm_fusion_enabled=false);
89

910
}}

0 commit comments

Comments
 (0)