
Commit ba3a90b

malfet authored and facebook-github-bot committed
Revert D28819780: [TensorExpr] Fix handling of 0-dim tensors.
Test Plan: revert-hammer

Differential Revision: D28819780

Original commit changeset: f3feff35a1ce

fbshipit-source-id: 1dca4ac9cea0b67e9f02800f6d5b3c7e4ae1d81a
1 parent 88fb5ee commit ba3a90b

8 files changed: +56 -129 lines

test/cpp/tensorexpr/test_kernel.cpp (+16 -13)

@@ -675,24 +675,24 @@ at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) {
 } // namespace

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST_F(Kernel, SumAllAxes) {
+TEST_F(Kernel, DISABLED_SumAllAxes) {
+  // [zero-dim tensors]
+  // NNC does not yet handle zero-dim tensors. aten::sum with no axis
+  // input returns a zero-dim tensors, so these tests must be disabled
+  // until we add support for zero-dim tensors.
   // Test lowering of sum on all axes.
   const auto graph_template = R"IR(
       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
         %1 : ${dtype}
-        %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1)
+        %2 : Tensor = aten::sum(%0, %1)
         return (%2))IR";
   auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

   for (auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
     KernelScope kernel_scope;
     TemplateEnv env;
     env.s("dtype", dtypeConstant(scalar_type));
-    if (scalar_type == ScalarType::Undefined) {
-      env.s("out_dtype", "Float");
-    } else {
-      env.s("out_dtype", "Double");
-    }
     const auto graph_string = format(graph_template, env);

     auto graph = std::make_shared<Graph>();

@@ -1104,16 +1104,17 @@ TEST_F(Kernel, Softmax4D) {
 }

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST_F(Kernel, InlineProducerIntoReduction) {
+TEST_F(Kernel, DISABLED_InlineProducerIntoReduction) {
+  // see : [zero-dim tensors]
   KernelScope kernel_scope;

   // Inline producer (mul) into reduction (sum).
   const auto graph_string = R"IR(
       graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
             %1 : Float(5, 3, strides=[3, 1], device=cpu)):
-        %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
+        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
         %3 : int = prim::Constant[value=7]()
-        %4 : Double(device=cpu) = aten::sum(%2, %3)
+        %4 : Float(5, 3, strides=[3, 1]) = aten::sum(%2, %3)
         return (%4))IR";
   auto graph = std::make_shared<Graph>();
   parseIR(graph_string, &*graph);

@@ -1144,7 +1145,9 @@ TEST_F(Kernel, InlineProducerIntoReduction) {
 }

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-TEST_F(Kernel, InlineReductionIntoConsumer) {
+TEST_F(Kernel, DISABLED_InlineReductionIntoConsumer) {
+  // see : [zero-dim tensors]
+
   KernelScope kernel_scope;

   // Inline producer (mul %2) into reduction (sum %4) but DO NOT

@@ -1154,8 +1157,8 @@ TEST_F(Kernel, InlineReductionIntoConsumer) {
             %1 : Float(5, 3, strides=[3, 1], device=cpu)):
         %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
         %3 : int = prim::Constant[value=6]()
-        %4 : Float(device=cpu) = aten::sum(%2, %3)
-        %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4)
+        %4 : Float(5, 3, strides=[3, 1]) = aten::sum(%2, %3)
+        %5 : Float(5, 3, strides=[3, 1]) = aten::mul(%2, %4)
         return (%5))IR";
   auto graph = std::make_shared<Graph>();
   parseIR(graph_string, &*graph);
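
Context for the [zero-dim tensors] note above: aten::sum with no dim argument collapses every axis, so its result is a rank-0 (scalar) tensor, which is the case NNC cannot lower after this revert. A minimal libtorch sketch of that behavior (standalone illustration, not part of test_kernel.cpp):

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor a = torch::rand({5, 3});
  torch::Tensor s = a.sum();             // no dim argument: reduce all axes
  std::cout << s.dim() << "\n";          // prints 0 -> rank-0 (scalar) tensor
  std::cout << s.item<float>() << "\n";  // the scalar value is still accessible
  torch::Tensor k = a.sum({0, 1}, /*keepdim=*/true);
  std::cout << k.dim() << "\n";          // prints 2: keepdim avoids the 0-dim case
  return 0;
}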

test/cpp/tensorexpr/test_reductions.cpp (-48)

@@ -23,54 +23,6 @@ namespace jit {

 using namespace torch::jit::tensorexpr;

-TEST(Reductions, ReduceSum0D_1) {
-  KernelScope kernel_scope;
-  const int M = 10;
-
-  Placeholder b(BufHandle("b", {M}, kFloat));
-  std::vector<float> in(M);
-  for (int j = 0; j < M; ++j) {
-    in[j] = j;
-  }
-
-  std::vector<float> out(M, -1.f);
-
-  Tensor* c = Reduce("sum", {{M, "m"}}, Sum(), b, {});
-  LoopNest loop({c});
-  loop.prepareForCodegen();
-  Stmt* s = loop.root_stmt();
-  s = IRSimplifier::simplify(s);
-
-  SimpleIREvaluator cg(s, {b, c});
-
-  cg.call({in, out});
-  for (int i = 0; i < M; ++i) {
-    ASSERT_EQ(out[i], in[i]);
-  }
-}
-
-TEST(Reductions, ReduceSum0D_2) {
-  KernelScope kernel_scope;
-  const int M = 10;
-
-  Placeholder b(BufHandle("b", {}, kFloat));
-  std::vector<float> in(1);
-  in[0] = 77.7;
-
-  std::vector<float> out(1, -1.f);
-
-  Tensor* c = Reduce("sum", {}, Sum(), b, {});
-  LoopNest loop({c});
-  loop.prepareForCodegen();
-  Stmt* s = loop.root_stmt();
-  s = IRSimplifier::simplify(s);
-
-  SimpleIREvaluator cg(s, {b, c});
-
-  cg.call({in, out});
-  ASSERT_EQ(out[0], in[0]);
-}
-
 // Sum an array to a single value.
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST(Reductions, ReduceSum1D) {

test/cpp/tensorexpr/test_te_fuser_pass.cpp (+6 -7)

@@ -114,22 +114,21 @@ TEST(TEFuserPass, FuserPass_3) {

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST(TEFuserPass, FuserPass_0DimInput) {
-  WithCPUFuser cf;
   const auto graph_string = R"IR(
-    graph(%x : Float(device=cpu),
-          %y : Float(device=cpu)):
+    graph(%x : Float(device=cuda),
+          %y : Float(device=cuda)):
       %one : int = prim::Constant[value=1]()
-      %a : Float(device=cpu) = aten::mul(%x, %y)
-      %b : Float(device=cpu) = aten::add(%x, %a, %one)
+      %a : Float(device=cuda) = aten::mul(%x, %y)
+      %b : Float(device=cuda) = aten::add(%x, %a, %one)
       return (%b))IR";
   auto g = std::make_shared<Graph>();
   torch::jit::parseIR(graph_string, g.get());

   g->lint();
   FuseTensorExprs(g);

-  // We should fuse 0-dim tensors too
-  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
+  // We should not fuse 0-dim tensors
+  testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
 }

 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)

test/test_jit_fuser_te.py (+13)

@@ -999,6 +999,7 @@ def fn_test_diamond(x, y):
         assert cx.elapsed_value() == 1
         self.assertEqual(out, x + y)

+    @unittest.skip("Reenable when TE will add support for 0-dim tensors")
     def test_scalar(self):
         def fn(x, y):
             return 2 * x + y

@@ -1972,6 +1973,7 @@ def te_compile(self, device, dtype, op):
         if op.name in skip_ops:
             return
         sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+        is_compiling = False
         for sample_input in sample_inputs_itr:
             arg_values = [sample_input.input] + list(sample_input.args)
             kwarg_values = sample_input.kwargs

@@ -2003,12 +2005,23 @@ def f({', '.join(param_names)}):
             f.__module__ = 'test'
             out = f(*param_values)

+            # NNC currently oftens segfault when asked to lower ops with 0-dim tensor outputs
+            if isinstance(out, torch.Tensor) and out.dim() == 0:
+                continue
+            else:
+                is_compiling = True
+
             ts_g = torch.jit.trace(f, param_values)
             kernel = torch._C._te.TensorExprKernel(ts_g.graph)
             correct_val = f(*param_values)
             self.assertEqual(kernel.run(tuple(param_values)), correct_val)
             self.assertEqual(kernel.fallback(tuple(param_values)), correct_val)

+        # If all sample inputs have scalar output, we won't have tested it and
+        # we consider the op to be not working
+        if not is_compiling:
+            raise RuntimeError("Skipped all inputs")
+
     @onlyCPU
     @unittest.skipIf(not LLVM_ENABLED, "Compiles with TensorExprKernel")
     @ops([op for op in op_db if get_name(op) in works_list], allowed_dtypes=(torch.float,))

test/test_tensorexpr.py (+5 -2)

@@ -529,6 +529,7 @@ def test(x, y):
         )
         self.assertLastGraphAllFused()

+    @unittest.skip("temporarily disable")
     def test_min_max_reduction(self):
         def test(x):
             return torch.min(x) + torch.max(x)

@@ -538,6 +539,7 @@ def test(x):
         np.testing.assert_allclose(warmup_and_run_forward(traced, a), np.amin(a.numpy()) + np.amax(a.numpy()))
         self.assertLastGraphAllFused()

+    @unittest.skip("temporarily disable")
     def test_min_max_reduction2(self):
         def test(x):
             return x.min() + x.max()

@@ -557,13 +559,14 @@ def test(x):
                 a.numpy(), axis=1) + np.amax(a.numpy(), axis=1))
         self.assertLastGraphAllFused()

+    @unittest.skip("temporarily disable")
     def test_min_max_reduction_dim1_2(self):
         def test(x):
-            return torch.min(x * x, 1)
+            return torch.min(x, 1)

         traced = torch.jit.trace(test, (torch.zeros(16, 16)))
         a = 8.0 * torch.rand(16, 16)
-        np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin((a * a).numpy(), axis=1))
+        np.testing.assert_allclose(warmup_and_run_forward(traced, a)[0], np.amin(a.numpy(), axis=1))
         self.assertLastGraphAllFused()

     def test_clamp(self):

torch/csrc/jit/passes/tensorexpr_fuser.cpp (+3)

@@ -847,6 +847,9 @@ class TensorExprFuser {
       if (!v->isCompleteTensor()) {
         return false;
       }
+      if (*v->type()->castRaw<TensorType>()->dim() == 0) {
+        return false;
+      }
     }
     return true;
   }
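
The guard re-added above makes the fuser reject any value whose tensor type has rank zero, so graphs touching 0-dim tensors fall back to the regular interpreter instead of NNC. A rough standalone sketch of the same predicate using plain shapes instead of the JIT TensorType API (the names here are illustrative, not the fuser's real helpers):

#include <cassert>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// A value is only fusible if its shape is fully known (std::nullopt models an
// incomplete TensorType) and its rank is not zero.
bool canHandleInput(const std::optional<Shape>& shape) {
  if (!shape.has_value()) {
    return false;  // incomplete tensor: cannot fuse
  }
  if (shape->empty()) {
    return false;  // rank-0 (0-dim) tensor: rejected again after this revert
  }
  return true;
}

int main() {
  assert(canHandleInput(Shape{5, 3}));    // complete 2-D shape: fusible
  assert(!canHandleInput(Shape{}));       // 0-dim tensor: not fusible
  assert(!canHandleInput(std::nullopt));  // unknown shape: not fusible
  return 0;
}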

torch/csrc/jit/tensorexpr/kernel.cpp (+13 -49)

@@ -1407,39 +1407,30 @@ Tensor* computeSum(
   // aten::sum takes the input tensor named self.
   auto sizes = valueShape(inputs[0]);

-  size_t rank = sizes.size();
+  int rank = sizes.size();
   if (inputs.size() > 2) {
-    if (auto emptyAxes = c10::get_if<BufList>(&inputs[1])) {
-      // If dim-array is an empty list, it will appear as BufList instead of
-      // IntList, and hence we need a special handling for it.
-      // In that case, we need to sum over all axes.
-      TORCH_INTERNAL_ASSERT(emptyAxes->empty());
-      axes.resize(rank);
-      std::iota(axes.begin(), axes.end(), 0);
-    } else if (rank > 0) {
-      auto nodeAxes = c10::get<IntList>(inputs[1]);
-      // Canonicalize axes: wrap around, sort and make unique.
-      for (auto axis : nodeAxes) {
-        axes.push_back(at::maybe_wrap_dim(axis, rank));
-      }
-      std::sort(axes.begin(), axes.end());
-      axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
+    auto nodeAxes = c10::get<IntList>(inputs[1]);
+    // Canonicalize axes: wrap around, sort and make unique.
+    for (auto axis : nodeAxes) {
+      axes.push_back(at::maybe_wrap_dim(axis, rank));
     }
+    std::sort(axes.begin(), axes.end());
+    axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
     keepdim = c10::get<bool>(inputs[2]);
   } else {
-    axes.resize(rank);
+    axes.resize(sizes.size());
     std::iota(axes.begin(), axes.end(), 0);
   }
   // Axes go into reduction dimensions.
   std::vector<DimArg> reductionDims;
-  reductionDims.reserve(rank);
+  reductionDims.reserve(sizes.size());
   for (size_t axis : axes) {
     reductionDims.emplace_back(sizes[axis]);
   }
   std::vector<DimArg> outputDims;
   // Output dimensions are the complement of axes. When keepdim is set, a
   // one-sized dimension is inserted for each axis.
-  for (size_t dim = 0; dim < rank; ++dim) {
+  for (size_t dim = 0; dim < sizes.size(); ++dim) {
     if (!std::count(axes.begin(), axes.end(), dim)) {
       outputDims.emplace_back(sizes[dim]);
     } else if (keepdim) {

@@ -2519,6 +2510,9 @@ Tensor* tensorexpr::computeOperandValue(
     }
     case aten::t: {
      auto shape = valueShape(inputs[0]);
+      if (shape.size() == 1) {
+        return new Tensor(c10::get<BufHandle>(inputs[0]).node(), nullptr);
+      }
      return computeOperandValue(
          aten::transpose,
          {inputs[0], (int64_t)1, (int64_t)0},

@@ -2527,17 +2521,6 @@ Tensor* tensorexpr::computeOperandValue(
    }
    case aten::transpose: {
      auto A = c10::get<BufHandle>(inputs[0]);
-      // Trivial case of 0-dim and 1-dim tensors: transpose is just a copy
-      if (A.ndim() < 1) {
-        return Compute(
-            "aten_transpose",
-            c10::fmap<DimArg>(outputShape),
-            [&](std::vector<VarHandle> axes) {
-              TORCH_INTERNAL_ASSERT(axes.size() <= 1);
-              return A.load(axes);
-            });
-      }
-      // Usual case where transpose actually swaps dimensions
      auto start_dim =
          at::maybe_wrap_dim(c10::get<int64_t>(inputs[1]), A.ndim());
      auto to_dim = at::maybe_wrap_dim(c10::get<int64_t>(inputs[2]), A.ndim());

@@ -2551,16 +2534,6 @@ Tensor* tensorexpr::computeOperandValue(
    }
    case aten::permute: {
      auto A = c10::get<BufHandle>(inputs[0]);
-      // Trivial case of 0-dim tensors: just a copy of the input
-      if (A.ndim() == 0) {
-        return Compute(
-            "aten_permute",
-            c10::fmap<DimArg>(outputShape),
-            [&](const std::vector<VarHandle>& axes) {
-              std::vector<ExprHandle> empty_indices;
-              return A.load(empty_indices);
-            });
-      }
      auto permute_dims = c10::get<IntList>(inputs[1]);
      return Compute(
          "aten_permute",

@@ -2590,15 +2563,6 @@ Tensor* tensorexpr::computeOperandValue(
    case aten::reshape:
    case aten::view: {
      auto A = c10::get<BufHandle>(inputs[0]);
-      if (A.ndim() == 0) {
-        return Compute(
-            "aten_view",
-            c10::fmap<DimArg>(outputShape),
-            [&](const std::vector<VarHandle>& axes) {
-              std::vector<ExprHandle> empty_indices;
-              return A.load(empty_indices);
-            });
-      }
      auto view_dims = c10::get<IntList>(inputs[1]);
      return Compute(
          "aten_reshape",

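The computeSum hunk above restores the pre-revert axis handling: when an explicit dim list is given it is wrapped, sorted, and de-duplicated, and when no dims are given every axis is reduced. A self-contained sketch of that canonicalization step, with a hand-rolled stand-in for at::maybe_wrap_dim (helper names are illustrative only, not the kernel's real code):

#include <algorithm>
#include <cassert>
#include <numeric>
#include <stdexcept>
#include <vector>

// Stand-in for at::maybe_wrap_dim: map a possibly negative axis into [0, rank).
int64_t wrapDim(int64_t axis, int64_t rank) {
  if (axis < -rank || axis >= rank) {
    throw std::out_of_range("axis out of range");
  }
  return axis < 0 ? axis + rank : axis;
}

// Canonicalize a dim list the way the restored computeSum does: wrap negative
// axes, sort, and drop duplicates. An empty dim list means "reduce all axes".
std::vector<size_t> canonicalizeAxes(const std::vector<int64_t>& dims, int64_t rank) {
  std::vector<size_t> axes;
  if (dims.empty()) {
    axes.resize(rank);
    std::iota(axes.begin(), axes.end(), 0);  // sum over every axis
    return axes;
  }
  for (int64_t d : dims) {
    axes.push_back(static_cast<size_t>(wrapDim(d, rank)));
  }
  std::sort(axes.begin(), axes.end());
  axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
  return axes;
}

int main() {
  // aten::sum(x, dim=[-1, 0, 0]) on a rank-2 input reduces over axes {0, 1}.
  assert((canonicalizeAxes({-1, 0, 0}, 2) == std::vector<size_t>{0, 1}));
  // No dim argument: reduce every axis, which is what produces a 0-dim output.
  assert((canonicalizeAxes({}, 2) == std::vector<size_t>{0, 1}));
  return 0;
}
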
torch/csrc/jit/tensorexpr/tensor.h (-10)

@@ -177,16 +177,6 @@ Tensor* Reduce(
   std::vector<const Var*> reduce_vars;
   unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);

-  // If reduce_vars is empty, then it's not a reduction, but rather a simple
-  // copy
-  if (reduce_vars.empty()) {
-    const Expr* body =
-        Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars))
-            .node();
-    Buf* func_result = new Buf(func_name, dims, body->dtype());
-    return new Tensor(func_result, vars, body);
-  }
-
   std::vector<const Var*> all_vars;
   all_vars.insert(all_vars.end(), vars.begin(), vars.end());
   all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
