Skip to content

Commit 8c7f4a0

Browse files
davidberard98facebook-github-bot
authored andcommitted
[tensorexpr] check for index out of bounds in ir_eval (pytorch#68858)
Summary: Pull Request resolved: pytorch#68858 when executing with ir_eval, check for index out of bounds. Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D32657881 Pulled By: davidberard98 fbshipit-source-id: 62dd0f85bb182b34e9c9f795ff761081290f6922
1 parent 76d282d commit 8c7f4a0

File tree

4 files changed

+179
-2
lines changed

4 files changed

+179
-2
lines changed

test/cpp/tensorexpr/test_expr.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,65 @@ TEST(Expr, DynamicShapeAdd) {
556556
testWithSize(37);
557557
}
558558

559+
TEST(Expr, OutOfBounds) {
560+
ExprHandle N(10);
561+
ExprHandle start(0);
562+
ExprHandle stop(15);
563+
VarHandle i("i", kInt);
564+
565+
BufHandle X("X", {N}, kInt);
566+
567+
auto body = Store::make(X, {i}, i);
568+
auto stmt = For::make(i, start, stop, body);
569+
570+
PaddedBuffer<int> data(20);
571+
572+
EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
573+
}
574+
575+
TEST(Expr, OutOfBounds2d) {
576+
std::vector<std::pair<int, int>> size_options = {{10, 15}, {15, 10}};
577+
for (auto sizes : size_options) {
578+
ExprHandle N(sizes.first);
579+
ExprHandle M(sizes.second);
580+
ExprHandle start(0);
581+
ExprHandle stopInner(15);
582+
ExprHandle stopOuter(15);
583+
VarHandle i("i", kInt);
584+
VarHandle j("j", kInt);
585+
586+
BufHandle X("X", {N, M}, kInt);
587+
588+
auto body = Store::make(X, {i, j}, i);
589+
auto inner = For::make(j, start, stopInner, body);
590+
auto stmt = For::make(i, start, stopOuter, inner);
591+
592+
PaddedBuffer<int> data(400);
593+
594+
EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
595+
}
596+
}
597+
598+
TEST(Expr, OutOfBounds2dFlattenedIndex) {
599+
ExprHandle buf_size(149);
600+
ExprHandle start(0);
601+
ExprHandle stopInner(15);
602+
ExprHandle stopOuter(10);
603+
VarHandle i("i", kInt);
604+
VarHandle j("j", kInt);
605+
606+
BufHandle X("X", {buf_size}, kInt);
607+
608+
auto idx = Add::make(Mul::make(i, stopInner), j);
609+
auto body = Store::make(X, {idx}, i);
610+
auto inner = For::make(j, start, stopInner, body);
611+
auto stmt = For::make(i, start, stopOuter, inner);
612+
613+
PaddedBuffer<int> data(400);
614+
615+
EXPECT_ANY_THROW(SimpleIREvaluator(stmt, {X})(data));
616+
}
617+
559618
void testCond01() {
560619
const int N = 16;
561620
PaddedBuffer<float> a_v(N);

test/cpp/tensorexpr/test_kernel.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,5 +1677,34 @@ TEST_F(Kernel, DISABLED_FlattenVectorize) {
16771677
#endif
16781678
}
16791679

1680+
TEST_F(Kernel, Strided1dWithinBounds) {
1681+
auto ir = R"IR(
1682+
graph(%0 : Float(3, strides=[1], device=cpu),
1683+
%1 : Float(3, strides=[2], device=cpu)):
1684+
%2 : int = prim::Constant[value=1]()
1685+
%3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
1686+
return (%3))IR";
1687+
auto graph = std::make_shared<Graph>();
1688+
std::unordered_map<std::string, Value*> vmap;
1689+
parseIR(ir, graph.get(), vmap);
1690+
TensorExprKernel k(graph);
1691+
1692+
auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
1693+
auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
1694+
.index({Slice(None, None, 2)});
1695+
auto expect = a + b;
1696+
1697+
std::vector<at::Tensor> inputs = {a, b};
1698+
1699+
std::vector<IValue> stack = fmap<IValue>(inputs);
1700+
k.run(stack);
1701+
1702+
auto output = stack[0].toTensor();
1703+
1704+
for (size_t i = 0; i < 3; ++i) {
1705+
CHECK_EQ(((float*)output.data_ptr())[i], ((float*)expect.data_ptr())[i]);
1706+
}
1707+
}
1708+
16801709
} // namespace jit
16811710
} // namespace torch

test/test_tensorexpr_pybind.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def test_external_calls(self):
6969

7070
def test_dynamic_shape(self):
7171
dN = te.VarHandle(torch.int32)
72-
A = te.BufHandle(torch.float64)
73-
B = te.BufHandle(torch.float64)
72+
A = te.BufHandle([dN], torch.float64)
73+
B = te.BufHandle([dN], torch.float64)
7474

7575
def compute(i):
7676
return A.load(i) - B.load(i)
@@ -92,6 +92,32 @@ def test_with_shape(n):
9292
test_with_shape(8)
9393
test_with_shape(31)
9494

95+
def test_dynamic_shape_2d(self):
96+
dN = te.VarHandle(torch.int32)
97+
dM = te.VarHandle(torch.int32)
98+
A = te.BufHandle([dN, dM], torch.float64)
99+
B = te.BufHandle([dN, dM], torch.float64)
100+
101+
def compute(i, j):
102+
return A.load([i, j]) - B.load([i, j])
103+
104+
C = te.Compute("C", [dN, dM], compute)
105+
106+
loopnest = te.LoopNest([C])
107+
loopnest.prepare_for_codegen()
108+
109+
cg = te.construct_codegen("ir_eval", loopnest.simplify(), [A, B, C, dN, dM])
110+
111+
def test_with_shape(n, m):
112+
tA = torch.randn(n, m, dtype=torch.double)
113+
tB = torch.randn(n, m, dtype=torch.double)
114+
tC = torch.empty(n, m, dtype=torch.double)
115+
cg.call([tA, tB, tC, n, m])
116+
torch.testing.assert_close(tA - tB, tC)
117+
118+
test_with_shape(2, 4)
119+
test_with_shape(5, 3)
120+
95121
def test_dtype_error(self):
96122
te.BufHandle("a", [1], torch.float32) # ok
97123
self.assertRaises(TypeError, lambda: te.BufHandle("a", [1], "float55"))

torch/csrc/jit/tensorexpr/eval.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,13 +672,74 @@ class SimpleIREvaluatorImpl : public IRVisitor {
672672
return {};
673673
}
674674

675+
void check_bounds_throw(int64_t idx, int64_t bound, const BufPtr& buf) {
676+
std::stringstream ss;
677+
ss << "Index out of bounds in check_bounds. Index: " << idx
678+
<< "; bounds: [0, " << bound << ").";
679+
throw malformed_input(ss.str(), buf);
680+
}
681+
682+
void check_bounds(const BufPtr& buf, const std::vector<ExprPtr>& indices) {
683+
const std::vector<ExprPtr>& dims = buf->dims();
684+
if (dims.size() != indices.size()) {
685+
// indices are flattened, but not buffer
686+
if (indices.size() == 1) {
687+
if (dims.size() != buf->strides().size()) {
688+
throw malformed_input(
689+
"Number of dimensions did not match number of strides", buf);
690+
}
691+
size_t buf_size = 1;
692+
if (dims.size() > 0) {
693+
ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
694+
ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));
695+
for (const auto& i : c10::irange(dims.size())) {
696+
buf_size_expr = buf_size_expr +
697+
((negative_one + ExprHandle(dims[i])) *
698+
ExprHandle(buf->strides()[i]));
699+
}
700+
buf_size_expr.node()->accept(this);
701+
buf_size = value().intValue();
702+
}
703+
indices[0]->accept(this);
704+
const auto& index_values = indexVec(value());
705+
for (auto& j : index_values) {
706+
if (j < 0 || j >= buf_size) {
707+
check_bounds_throw(j, buf_size, buf);
708+
}
709+
}
710+
return;
711+
}
712+
throw malformed_input(
713+
"dimensions and indices mismatch in check_bounds. Buf has " +
714+
std::to_string(dims.size()) + " dimensions and indices has " +
715+
std::to_string(indices.size()) + " dimensions.",
716+
buf);
717+
}
718+
for (const auto& i : c10::irange(dims.size())) {
719+
auto opt_dim = intValue(dims[i]);
720+
if (!opt_dim) {
721+
continue;
722+
}
723+
auto dim_bound = *opt_dim;
724+
indices[i]->accept(this);
725+
const auto& ithDimIndices = indexVec(value());
726+
for (auto& j : ithDimIndices) {
727+
if (j < 0 || j >= dim_bound) {
728+
check_bounds_throw(j, dim_bound, buf);
729+
}
730+
}
731+
}
732+
}
733+
675734
TORCH_API void visit(LoadPtr v) override {
676735
auto iter = buffer_mapping_.find(v->buf());
677736
if (iter == buffer_mapping_.end()) {
678737
throw malformed_input("could not find base node in Load", v);
679738
}
680739
void* ptr = iter->second;
681740

741+
check_bounds(v->buf(), v->indices());
742+
682743
ExprPtr flat_idx =
683744
flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
684745
flat_idx->accept(this);
@@ -722,6 +783,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
722783

723784
void* ptr = iter->second;
724785

786+
check_bounds(v->buf(), v->indices());
787+
725788
ExprPtr flat_idx =
726789
flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
727790
flat_idx->accept(this);

0 commit comments

Comments
 (0)