diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index dcef33f93d6ae..2b3af81c0a71d 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -161,16 +161,19 @@ struct GuardElimination { // `checkInputs` check the invariants specified in `removableGuard` // on inputs to `n`. The invariants must hold, or an input must - // be a `prim::Constant` or be of `NumberType` or be included - // as an exception in `except` - bool checkInputs(Node* n, const std::unordered_set& except) { + // be a `prim::Constant` or be of `NumberType` if `allow_numbers` is `true` + // or be included as an exception in `except` + bool checkInputs( + Node* n, + const std::unordered_set& except, + bool allow_numbers = true) { bool all_inputs_guarded = true; size_t i = 0; for (auto input : n->inputs()) { if ((input->node()->kind() == prim::Guard && !input->type()->expect()->isSummarized()) || input->node()->kind() == prim::Constant || - input->type()->isSubtypeOf(NumberType::get()) || + (allow_numbers && input->type()->isSubtypeOf(NumberType::get())) || except.count(i) != 0) { AT_ASSERT( input->node()->kind() != prim::Guard || @@ -260,7 +263,6 @@ struct GuardElimination { case aten::pow: case aten::relu: case aten::threshold: - case aten::avg_pool2d: case prim::AutogradAdd: case prim::AutogradZero: case aten::rand_like: @@ -296,6 +298,8 @@ struct GuardElimination { n->input(3)->node()->kind() == prim::Constant && // the stride is constant n->input(4)->node()->kind() == prim::Constant; + case aten::avg_pool2d: + return checkInputs(n, no_exceptions, false); case aten::unsqueeze: // check that the dimension argument is constant return !n->input(0)->type()->expect()->isSummarized() &&