diff --git a/xls/passes/basic_simplification_pass.cc b/xls/passes/basic_simplification_pass.cc index 5e92e5263f..798ad84b92 100644 --- a/xls/passes/basic_simplification_pass.cc +++ b/xls/passes/basic_simplification_pass.cc @@ -595,6 +595,96 @@ absl::StatusOr MatchPatterns(Node* n) { return true; } + // Pattern: AND subset check (mask subset). + // + // This is also commonly described as an AND "absorption" rewrite: if y has no + // bits set outside x, then `y & x == y`. The rewritten form makes that "no + // bits set outside x" condition explicit as `(y & ~x) == 0`. + // + // Why "subset": + // Think of a bitvector as the set of positions where it has a 1-bit. + // Then eq(and(x, y), y) holds when that set for y is a subset of the + // set for x. y has no bits set outside the 1-bits of x. + // + // For same-width bitvectors, the following are equivalent: + // - `eq(and(x, y), y)` means: every 1-bit in y is also a 1-bit in x + // + // - `eq(and(not(x), y), 0)` means: y has no 1-bits in positions where x has a + // 0-bit (equivalently, and(not(x), y) has no 1-bits at all). + // + // So we can rewrite: + // `eq(and(x, y), y)` <=> `eq(and(not(x), y), 0)` + // `eq(and(x, y), x)` <=> `eq(and(not(y), x), 0)` + // + // This generalizes to n-ary and: + // + // `eq(and(a, b, ..., t), t)` <=> `eq(and(not(and(a, b, ...)), t), 0)` + // + // And similarly for `ne`. + if (n->op() == Op::kEq || n->op() == Op::kNe) { + auto try_rewrite_and_subset_check = + [&](Node* and_node, Node* subset) -> absl::StatusOr { + if (and_node->op() != Op::kAnd) { + return false; + } + + std::vector mask_operands; + mask_operands.reserve(and_node->operand_count()); + bool found_subset = false; + for (Node* operand : and_node->operands()) { + if (operand == subset) { + found_subset = true; + continue; + } + mask_operands.push_back(operand); + } + if (!found_subset) { + return false; + } + + FunctionBase* f = n->function_base(); + Node* mask = nullptr; + if (mask_operands.empty()) { + XLS_ASSIGN_OR_RETURN( + mask, + f->MakeNode(n->loc(), AllOnesOfType(subset->GetType()))); + } else if (mask_operands.size() == 1) { + mask = mask_operands[0]; + } else { + XLS_ASSIGN_OR_RETURN( + mask, f->MakeNode(n->loc(), mask_operands, Op::kAnd)); + } + + XLS_ASSIGN_OR_RETURN(Node * not_mask, + f->MakeNode(n->loc(), mask, Op::kNot)); + XLS_ASSIGN_OR_RETURN( + Node * disallowed_bits, + f->MakeNode(n->loc(), std::vector{not_mask, subset}, + Op::kAnd)); + XLS_ASSIGN_OR_RETURN( + Node * zero, + f->MakeNode(n->loc(), ZeroOfType(subset->GetType()))); + XLS_ASSIGN_OR_RETURN( + Node * replacement, + f->MakeNode(n->loc(), disallowed_bits, zero, n->op())); + + VLOG(2) << "FOUND: and subset check via compare-to-zero"; + XLS_RETURN_IF_ERROR(n->ReplaceUsesWith(replacement)); + return true; + }; + + XLS_ASSIGN_OR_RETURN(bool rewritten, try_rewrite_and_subset_check( + n->operand(0), n->operand(1))); + if (rewritten) { + return true; + } + XLS_ASSIGN_OR_RETURN( + rewritten, try_rewrite_and_subset_check(n->operand(1), n->operand(0))); + if (rewritten) { + return true; + } + } + // Patterns (where x is a bits[1] type): // eq(x, 1) => x // eq(x, 0) => not(x) diff --git a/xls/passes/basic_simplification_pass_test.cc b/xls/passes/basic_simplification_pass_test.cc index 39d9130cb7..bdb0b065bf 100644 --- a/xls/passes/basic_simplification_pass_test.cc +++ b/xls/passes/basic_simplification_pass_test.cc @@ -675,6 +675,75 @@ TEST_F(BasicSimplificationPassTest, EqOfSwappedTwoWaySelects) { EXPECT_THAT(f->return_value(), m::Eq(m::Param("a"), m::Param("b"))); } +TEST_F(BasicSimplificationPassTest, EqAndSubsetCheckNary) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue z = fb.Param("z", p->GetBitsType(32)); + BValue and_xyz = fb.And({x, y, z}); + fb.Eq(and_xyz, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Eq(m::And(m::Not(m::And(m::Param("x"), m::Param("z"))), m::Param("y")), + m::Literal(0))); +} + +TEST_F(BasicSimplificationPassTest, EqAndSubsetCheckNaryCommutedEq) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue z = fb.Param("z", p->GetBitsType(32)); + BValue and_xyz = fb.And({x, y, z}); + fb.Eq(y, and_xyz); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Eq(m::And(m::Not(m::And(m::Param("x"), m::Param("z"))), m::Param("y")), + m::Literal(0))); +} + +TEST_F(BasicSimplificationPassTest, NeAndSubsetCheckNary) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue z = fb.Param("z", p->GetBitsType(32)); + BValue and_xyz = fb.And({x, y, z}); + fb.Ne(and_xyz, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Ne(m::And(m::Not(m::And(m::Param("x"), m::Param("z"))), m::Param("y")), + m::Literal(0))); +} + +TEST_F(BasicSimplificationPassTest, EqAndSubsetCheckDoesNotMatchNonOperand) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue z = fb.Param("z", p->GetBitsType(32)); + BValue and_xy = fb.And({x, y}); + fb.Eq(and_xy, z); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + ASSERT_THAT(Run(p.get()), IsOkAndHolds(false)); + EXPECT_THAT(f->return_value(), + m::Eq(m::And(m::Param("x"), m::Param("y")), m::Param("z"))); +} + TEST_F(BasicSimplificationPassTest, SubOfSwappedTwoWaySelectsDoesNotSimplify) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get());