diff --git a/xls/passes/basic_simplification_pass.cc b/xls/passes/basic_simplification_pass.cc index 5e92e5263f..274e5dad74 100644 --- a/xls/passes/basic_simplification_pass.cc +++ b/xls/passes/basic_simplification_pass.cc @@ -68,6 +68,66 @@ bool IsBinaryIrreflexiveRelation(Node* node) { } } +struct EqNeSelectNotOrIncMatch { + Node* x; + Op cmp; +}; + +// Matches either: +// (eq|ne)(sel(c, cases=[not(x), add(x, 1)]), 0) +// (eq|ne)(0, sel(c, cases=[not(x), add(x, 1)])) +// +// Also handles swapped select arms. +// +// This is exploiting "not/inc zero-equivalence": the arms `~x` and `x+1` are +// both equal to zero under the same precondition, namely `x == all_ones`. +absl::StatusOr> +TryMatchEqOrNeSelectNotOrIncAgainstZero(Node* n) { + if (n->op() != Op::kEq && n->op() != Op::kNe) { + return std::nullopt; + } + Node* maybe_sel = nullptr; + if (IsLiteralZero(n->operand(0))) { + maybe_sel = n->operand(1); + } else if (IsLiteralZero(n->operand(1))) { + maybe_sel = n->operand(0); + } else { + return std::nullopt; + } + + XLS_ASSIGN_OR_RETURN(std::optional sel, + MatchBinarySelectLike(maybe_sel)); + if (!sel.has_value()) { + return std::nullopt; + } + + auto match_add_one = [](Node* node) -> Node* { + if (node->op() != Op::kAdd) { + return nullptr; + } + if (IsLiteralUnsignedOne(node->operand(0))) { + return node->operand(1); + } + if (IsLiteralUnsignedOne(node->operand(1))) { + return node->operand(0); + } + return nullptr; + }; + + // on_false = not(x), on_true = add(x, 1) + if (Node* x = match_add_one(sel->on_true); + x != nullptr && IsNotOf(sel->on_false, x)) { + return EqNeSelectNotOrIncMatch{.x = x, .cmp = n->op()}; + } + // swapped arms + if (Node* x = match_add_one(sel->on_false); + x != nullptr && IsNotOf(sel->on_true, x)) { + return EqNeSelectNotOrIncMatch{.x = x, .cmp = n->op()}; + } + + return std::nullopt; +} + // MatchPatterns matches simple tree patterns to find opportunities // for simplification. // @@ -94,6 +154,32 @@ absl::StatusOr MatchPatterns(Node* n) { return true; } + // Pattern: + // `(eq|ne)(sel(c, cases=[not(x), add(x, 1)]), 0)` => + // `(eq|ne)(x, all_ones(width(x)))` + // + // Why this works (not/inc zero-equivalence): + // - `(~x == 0)` <=> `(x == all_ones)` + // - `(x+1 == 0)` <=> `(x == all_ones)` (note: given wraparound arithmetic) + // Therefore for `y = sel(c, [~x, x+1])`, `(y != 0)` is exactly `(x != + // all_ones)`, independent of selector `c`. + XLS_ASSIGN_OR_RETURN(std::optional not_or_inc_match, + TryMatchEqOrNeSelectNotOrIncAgainstZero(n)); + if (not_or_inc_match.has_value()) { + Node* x = not_or_inc_match->x; + XLS_ASSIGN_OR_RETURN( + Node * all_ones, + n->function_base()->MakeNode( + n->loc(), Value(Bits::AllOnes(x->BitCountOrDie())))); + VLOG(2) + << "FOUND: not/inc zero-equivalence: " + "(eq|ne)(sel(c, [not(x), add(x, 1)]), 0) => (eq|ne)(x, all_ones)"; + XLS_RETURN_IF_ERROR( + n->ReplaceUsesWithNew(x, all_ones, not_or_inc_match->cmp) + .status()); + return true; + } + // Returns true if all operands of 'node' are the same. auto all_operands_same = [&query_engine](Node* node) { return std::all_of(node->operands().begin(), node->operands().end(), diff --git a/xls/passes/basic_simplification_pass_test.cc b/xls/passes/basic_simplification_pass_test.cc index 39d9130cb7..2c36365ff9 100644 --- a/xls/passes/basic_simplification_pass_test.cc +++ b/xls/passes/basic_simplification_pass_test.cc @@ -693,6 +693,32 @@ TEST_F(BasicSimplificationPassTest, SubOfSwappedTwoWaySelectsDoesNotSimplify) { m::Select(m::Param("p"), {m::Param("b"), m::Param("a")}))); } +TEST_F(BasicSimplificationPassTest, EqNeOfSelectNotOrIncrementAgainstZero) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue c = fb.Param("c", p->GetBitsType(1)); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue nx = fb.Not(x); + BValue one = fb.Literal(UBits(1, 32)); + BValue inc = fb.Add(x, one); + BValue s = fb.Select(c, {nx, inc}); + BValue zero = fb.Literal(UBits(0, 32)); + BValue eq = fb.Eq(s, zero); + BValue ne = fb.Ne(s, zero); + BValue r = fb.Tuple({eq, ne}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(r)); + + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Tuple( + testing::AnyOf(m::Eq(m::Param("x"), m::Literal(Bits::AllOnes(32))), + m::Eq(m::Literal(Bits::AllOnes(32)), m::Param("x"))), + testing::AnyOf(m::Ne(m::Param("x"), m::Literal(Bits::AllOnes(32))), + m::Ne(m::Literal(Bits::AllOnes(32)), m::Param("x"))))); +} + TEST_F(BasicSimplificationPassTest, AddWithZero) { auto p = CreatePackage(); XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(