
Commit 05b6dc9

qxy11 authored and facebook-github-bot committed
Fix BatchMatMul test and shape inference (pytorch#66733)
Summary:
Pull Request resolved: pytorch#66733

Fix the BatchMatMul test to compare glow/caffe2 outputs, and fix its shape-inference function, which made simplifying assumptions about broadcasting and failed on some of the shapes in the test. The previous inference failed whenever the first n - 2 output dimensions of A x B were not simply those of whichever of A or B had higher rank (e.g. for A: [2, 2, 2, 3, 4] and B: [3, 1, 2, 2, 4, 5] we expect output dimensions [3, 2, 2, 2, 3, 5] rather than [3, 1, 2, 2, 3, 5]).

Test Plan:
```
buck test glow/fb/test/numerics:test_operator_onnxifinnpi -- -r .*test_batch_matmul_manydims.* --env USE_INF_API=1
```

Reviewed By: khabinov

Differential Revision: D31701184

fbshipit-source-id: 31d0fb17409a399b90fb8042385e000ed81c3581
1 parent 9f782f8 commit 05b6dc9

File tree

1 file changed (+13, −5 lines)


caffe2/operators/batch_matmul_op.cc

Lines changed: 13 additions & 5 deletions
```diff
@@ -73,11 +73,19 @@ vector<TensorShape> TensorInferenceForBatchMatMul(
     N = dims_B[ndims_B - 1];
   }

-  std::vector<int64_t> new_dims;
-  if (ndims_A >= ndims_B) {
-    new_dims.assign(dims_A.begin(), dims_A.end() - 2);
-  } else {
-    new_dims.assign(dims_B.begin(), dims_B.end() - 2);
+  const int ndims = std::max(ndims_A, ndims_B);
+  std::vector<int64_t> new_dims(ndims - 2);
+  std::vector<int64_t> dims_A_broadcast(ndims - 2, 1);
+  std::vector<int64_t> dims_B_broadcast(ndims - 2, 1);
+
+  std::copy_n(dims_A.begin(), ndims_A - 2, dims_A_broadcast.begin() + ndims - ndims_A);
+  std::copy_n(dims_B.begin(), ndims_B - 2, dims_B_broadcast.begin() + ndims - ndims_B);
+  for (int i = 0; i < ndims - 2; ++i) {
+    if (!dims_A_broadcast[i] || !dims_B_broadcast[i]) {
+      new_dims[i] = 0;
+    } else {
+      new_dims[i] = std::max(dims_A_broadcast[i], dims_B_broadcast[i]);
+    }
   }
   if (!A_broadcasted) {
     new_dims.push_back(M);
```
