 #include "torch/csrc/jit/passes/batch_mm.h"
 
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/peephole.h"
 #include "torch/csrc/jit/interned_strings.h"
 #include "torch/csrc/jit/constants.h"
 #include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/assertions.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/utils/functional.h"
 
 #include <ATen/ATen.h>
@@ -68,52 +70,98 @@ namespace torch { namespace jit {
 // the trees we formed and fuse them.
 
 // Tunable parameter. Set to something larger if it turns out to be better.
-static constexpr size_t min_fusion_size = 2;
+static constexpr size_t min_fusion_size = 4;
 
-static std::array<int64_t, 2> as_array(at::IntList sizes) {
-  JIT_ASSERT(sizes.size() == 2);
-  std::array<int64_t, 2> arr = {sizes[0], sizes[1]};
-  return arr;
+bool have_same_shape(at::TensorList inputs) {
+  auto expected_sizes = inputs[0].sizes();
+  return std::all_of(inputs.begin(), inputs.end(),
+                     [expected_sizes](const at::Tensor& t) {
+                       return t.sizes() == expected_sizes;
+                     });
 }
 
+bool shape_is_fast(const at::Tensor& lhs, const at::Tensor& rhs) {
+  size_t l = lhs.size(0);
+  size_t m = lhs.size(1);
+  size_t r = rhs.size(1);
+  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
+  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
+}
+
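+// MMTreeReduce computes mm(lhs_0, rhs_0) + mm(lhs_1, rhs_1) + ... over the 2N
+// tensors it pops off the stack (all N LHS operands, then all N RHS operands).
+// When the shapes on each side agree, the whole sum folds into one big GEMM,
+// since A1*B1 + A2*B2 == cat({A1, A2}, /*dim=*/1) * cat({B1, B2}, /*dim=*/0).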
+RegisterOperators mm_tree_reduction_reg({
+  Operator(
+    Symbol::prim("MMTreeReduce"),
+    [](const Node* node) {
+      size_t num_inputs = node->inputs().size();
+      return [num_inputs](Stack& stack) {
+        std::vector<at::Tensor> inputs;
+        inputs.reserve(num_inputs);
+        for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
+          inputs.push_back(std::move(*it).toTensor());
+        }
+        drop(stack, num_inputs);
+
+        JIT_ASSERT(inputs.size() > 0);
+        JIT_ASSERT(inputs.size() % 2 == 0);
+        size_t side_num_elems = inputs.size() / 2;
+        auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
+        auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
+        // TODO: checking this is not free, so we should stop if this keeps failing
+        // TODO: benchmark to find when this is really a win, and add size constraints
+        if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) && shape_is_fast(lhs_inputs[0], rhs_inputs[0])) {
+          auto lhs = at::cat(lhs_inputs, /*dim=*/1);
+          auto rhs = at::cat(rhs_inputs, /*dim=*/0);
+          push(stack, at::mm(lhs, rhs));
+        } else {
+          auto acc = at::mm(inputs[0], inputs[side_num_elems]);
+          for (size_t i = 1; i < side_num_elems; ++i) {
+            acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
+          }
+          push(stack, std::move(acc));
+        }
+        return 0;
+      };
+    })
+});
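+// NB: the creator lambda above runs once per prim::MMTreeReduce node (capturing
+// its input count); the closure it returns is what the interpreter executes.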
+
 // TreeTokens will be used to label nodes of the graph, if the nodes will fit
 // our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
 // when we reach node N with inputs A and B, then A and B have already been
 // processed, and we can try to unify their TreeTokens (if they have them)
 // and build a larger tree.
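 //
 // For example, in r = mm(A1, B1) + mm(A2, B2) each mm gets a size-1 token,
 // and the add node unifies them into a size-2 token rooted at itself.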
 struct TreeToken {
   uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
-  std::array<int64_t, 2> lhs_sizes{{0, 0}};
-  std::array<int64_t, 2> rhs_sizes{{0, 0}};
   Node *node = nullptr;
   bool is_root = false;
 
-  static TreeToken fromMM(Node *mm) {
+  static TreeToken mm(Node *mm) {
     TreeToken token;
     token.tree_size = 1;
-    Value *lhs = mm->inputs()[0];
-    Value *rhs = mm->inputs()[1];
-    token.lhs_sizes = as_array(lhs->type()->expect<CompleteTensorType>()->sizes());
-    token.rhs_sizes = as_array(rhs->type()->expect<CompleteTensorType>()->sizes());
     token.node = mm;
     token.is_root = true;
     return token;
   }
 
-  static TreeToken unify(Node *add, TreeToken& l, TreeToken& r) {
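+  // A transpose is absorbed only when its input is itself an mm node; the
+  // gather step below rewrites it into an mm of transposes via (A*B)^T == B^T * A^T.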
+  // NB: the returned token might be invalid, so make sure to check its boolean value!
+  static TreeToken transpose(Node *t, TreeToken& inp_token) {
+    TreeToken token;
+    if (!inp_token.node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+      return token;
+    }
+    token.tree_size = 1;
+    token.node = t;
+    token.is_root = true;
+    inp_token.is_root = false;
+    return token;
+  }
+
+  // NB: the returned token might be invalid, so make sure to check its boolean value!
+  static TreeToken add(Node *add, TreeToken& l, TreeToken& r) {
     TreeToken token;
     // See Note [Overlapping trees]
     if (&l == &r || !l.is_root || !r.is_root)
       return token;
-    // We can batch the tree only if all sizes match, because we need to
-    // cat inputs for both operands
-    if (l.lhs_sizes != r.lhs_sizes)
-      return token;
-    if (l.rhs_sizes != r.rhs_sizes)
-      return token;
     token.tree_size = l.tree_size + r.tree_size;
-    token.lhs_sizes = l.lhs_sizes;
-    token.rhs_sizes = l.rhs_sizes;
     token.node = add;
     token.is_root = true;
     l.is_root = r.is_root = false; // Reserve the subtrees, so they can't be used again.
@@ -124,16 +172,31 @@ struct TreeToken {
     return is_root;
   }
 
-  std::vector<Node*> gatherMatMuls() {
+  std::vector<Node*> removeTransposesAndGatherMatmuls() {
     std::vector<Node*> matmuls;
     std::vector<Node*> queue {node};
+    Graph* graph = node->owningGraph();
     while (!queue.empty()) {
       auto n = queue.back(); queue.pop_back();
-      if (n->kind() == aten::mm) {
+      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
         matmuls.push_back(n);
-      } else {
+      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
+        Node * input_node = n->input()->node();
+        JIT_ASSERT(input_node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor"));
+        // (AB)^T == B^T A^T
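+        // i.e. t(mm(A, B)) is replaced with mm(t(B), t(A)), so the transpose
+        // disappears and the new mm becomes an ordinary leaf of the tree.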
+        WithInsertPoint insert_guard { input_node };
+        Value * A = input_node->inputs()[0];
+        Value * B = input_node->inputs()[1];
+        Value * AT = graph->insert(aten::t, {A});
+        Value * BT = graph->insert(aten::t, {B});
+        Value * BTAT = graph->insert(aten::mm, {BT, AT});
+        n->output()->replaceAllUsesWith(BTAT);
+        matmuls.push_back(BTAT->node());
+      } else if (n->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
         queue.push_back(n->inputs()[0]->node());
         queue.push_back(n->inputs()[1]->node());
+      } else {
+        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
       }
     }
     return matmuls;
@@ -147,13 +210,14 @@ void BatchMMBlock(Block* block) {
   // Look for trees in the block
   std::unordered_map<Node*, TreeToken> tokens;
   for (auto node : block->nodes()) {
-    if (node->kind() == aten::mm &&
-        node->input(0)->type()->cast<CompleteTensorType>() &&
-        node->input(1)->type()->cast<CompleteTensorType>()) {
-      tokens[node] = TreeToken::fromMM(node);
-    } else if (node->kind() == aten::add) {
-      // NOTE: x + 2 is add[other={2}](%x)
-      if (node->inputs().size() != 2) continue;
+    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+      tokens[node] = TreeToken::mm(node);
+    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
+      auto input_it = tokens.find(node->input()->node());
+      if (input_it != tokens.end()) {
+        tokens[node] = TreeToken::transpose(node, input_it->second);
+      }
+    } else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
       Node *lhs = node->inputs()[0]->node();
       Node *rhs = node->inputs()[1]->node();
       auto lhs_it = tokens.find(lhs);
@@ -166,8 +230,9 @@ void BatchMMBlock(Block* block) {
       // we need to compute a transitive closure and actually check the dependencies.
       if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
           lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {
-        if (auto token = TreeToken::unify(node, lhs_it->second, rhs_it->second))
+        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
           tokens[node] = token;
+        }
       }
     } else {
       for (auto block : node->blocks()) {
@@ -181,35 +246,26 @@ void BatchMMBlock(Block* block) {
     auto& root = item.second;
     if (!root || root.tree_size < min_fusion_size)
       continue;
-    auto matmuls = root.gatherMatMuls();
-    auto type_ = root.node->output()->type();
-    auto type = type_->expect<CompleteTensorType>();
-
-    auto batch_inputs = [&](Side s, std::array<int64_t, 2> cat_sizes) -> Value* {
-      int inputs_off = s == Side::LHS ? 0 : 1;
-      int cat_dim = s == Side::LHS ? 1 : 0;
-      cat_sizes[cat_dim] *= matmuls.size(); // make them really cat_sizes
-
-      WithInsertPoint iguard { root.node };
-      auto inputs = fmap(matmuls, [=](Node *mm) -> SymbolicVariable { return mm->inputs()[inputs_off]; });
-      auto cat_output = SymbolicVariable::cat(inputs, cat_dim).value();
-      cat_output->setType(type->withSizes(cat_sizes));
-      return cat_output;
-    };
-
-    auto lhs_batch = batch_inputs(Side::LHS, root.lhs_sizes);
-    auto rhs_batch = batch_inputs(Side::RHS, root.rhs_sizes);
-    Node *batch_mm = graph->create(aten::mm, {lhs_batch, rhs_batch});
-    batch_mm->output()->setType(type_);
-    batch_mm->insertBefore(root.node);
-    root.node->output()->replaceAllUsesWith(batch_mm->output());
+    auto matmuls = root.removeTransposesAndGatherMatmuls();
+    WithInsertPoint insert_guard {root.node};
+    Node * tree_reduce = graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
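+    // Feed all LHS operands first, then all RHS operands; the MMTreeReduce
+    // kernel above splits its input list in half using the same convention.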
+    for (Node * matmul : matmuls) {
+      tree_reduce->addInput(matmul->inputs().at(0));
+    }
+    for (Node * matmul : matmuls) {
+      tree_reduce->addInput(matmul->inputs().at(1));
+    }
+    root.node->output()->replaceAllUsesWith(tree_reduce->output());
     // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
   }
-  EliminateDeadCode(block);
 }
 
 void BatchMM(std::shared_ptr<Graph>& graph) {
   BatchMMBlock(graph->block());
+  EliminateDeadCode(graph);
+  // It's possible that transpose rearrangements have created sequences of consecutive
+  // transposes that didn't exist before.
+  PeepholeOptimize(graph);
 }
 
 }}