Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions xls/passes/basic_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,66 @@ bool IsBinaryIrreflexiveRelation(Node* node) {
}
}

struct EqNeSelectNotOrIncMatch {
Node* x;
Op cmp;
};

// Matches either:
// (eq|ne)(sel(c, cases=[not(x), add(x, 1)]), 0)
// (eq|ne)(0, sel(c, cases=[not(x), add(x, 1)]))
//
// Also handles swapped select arms.
//
// This is exploiting "not/inc zero-equivalence": the arms `~x` and `x+1` are
// both equal to zero under the same precondition, namely `x == all_ones`.
absl::StatusOr<std::optional<EqNeSelectNotOrIncMatch>>
TryMatchEqOrNeSelectNotOrIncAgainstZero(Node* n) {
if (n->op() != Op::kEq && n->op() != Op::kNe) {
return std::nullopt;
}
Node* maybe_sel = nullptr;
if (IsLiteralZero(n->operand(0))) {
maybe_sel = n->operand(1);
} else if (IsLiteralZero(n->operand(1))) {
maybe_sel = n->operand(0);
} else {
return std::nullopt;
}

XLS_ASSIGN_OR_RETURN(std::optional<BinarySelectView> sel,
MatchBinarySelectLike(maybe_sel));
if (!sel.has_value()) {
return std::nullopt;
}

auto match_add_one = [](Node* node) -> Node* {
if (node->op() != Op::kAdd) {
return nullptr;
}
if (IsLiteralUnsignedOne(node->operand(0))) {
return node->operand(1);
}
if (IsLiteralUnsignedOne(node->operand(1))) {
return node->operand(0);
}
return nullptr;
};

// on_false = not(x), on_true = add(x, 1)
if (Node* x = match_add_one(sel->on_true);
x != nullptr && IsNotOf(sel->on_false, x)) {
return EqNeSelectNotOrIncMatch{.x = x, .cmp = n->op()};
}
// swapped arms
if (Node* x = match_add_one(sel->on_false);
x != nullptr && IsNotOf(sel->on_true, x)) {
return EqNeSelectNotOrIncMatch{.x = x, .cmp = n->op()};
}

return std::nullopt;
}

// MatchPatterns matches simple tree patterns to find opportunities
// for simplification.
//
Expand All @@ -94,6 +154,32 @@ absl::StatusOr<bool> MatchPatterns(Node* n) {
return true;
}

// Pattern:
// `(eq|ne)(sel(c, cases=[not(x), add(x, 1)]), 0)` =>
// `(eq|ne)(x, all_ones(width(x)))`
//
// Why this works (not/inc zero-equivalence):
// - `(~x == 0)` <=> `(x == all_ones)`
// - `(x+1 == 0)` <=> `(x == all_ones)` (note: given wraparound arithmetic)
// Therefore for `y = sel(c, [~x, x+1])`, `(y != 0)` is exactly `(x !=
// all_ones)`, independent of selector `c`.
XLS_ASSIGN_OR_RETURN(std::optional<EqNeSelectNotOrIncMatch> not_or_inc_match,
TryMatchEqOrNeSelectNotOrIncAgainstZero(n));
if (not_or_inc_match.has_value()) {
Node* x = not_or_inc_match->x;
XLS_ASSIGN_OR_RETURN(
Node * all_ones,
n->function_base()->MakeNode<Literal>(
n->loc(), Value(Bits::AllOnes(x->BitCountOrDie()))));
VLOG(2)
<< "FOUND: not/inc zero-equivalence: "
"(eq|ne)(sel(c, [not(x), add(x, 1)]), 0) => (eq|ne)(x, all_ones)";
XLS_RETURN_IF_ERROR(
n->ReplaceUsesWithNew<CompareOp>(x, all_ones, not_or_inc_match->cmp)
.status());
return true;
}

// Returns true if all operands of 'node' are the same.
auto all_operands_same = [&query_engine](Node* node) {
return std::all_of(node->operands().begin(), node->operands().end(),
Expand Down
26 changes: 26 additions & 0 deletions xls/passes/basic_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,32 @@ TEST_F(BasicSimplificationPassTest, SubOfSwappedTwoWaySelectsDoesNotSimplify) {
m::Select(m::Param("p"), {m::Param("b"), m::Param("a")})));
}

TEST_F(BasicSimplificationPassTest, EqNeOfSelectNotOrIncrementAgainstZero) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
BValue c = fb.Param("c", p->GetBitsType(1));
BValue x = fb.Param("x", p->GetBitsType(32));
BValue nx = fb.Not(x);
BValue one = fb.Literal(UBits(1, 32));
BValue inc = fb.Add(x, one);
BValue s = fb.Select(c, {nx, inc});
BValue zero = fb.Literal(UBits(0, 32));
BValue eq = fb.Eq(s, zero);
BValue ne = fb.Ne(s, zero);
BValue r = fb.Tuple({eq, ne});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(r));

ScopedVerifyEquivalence sve(f);
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));
EXPECT_THAT(
f->return_value(),
m::Tuple(
testing::AnyOf(m::Eq(m::Param("x"), m::Literal(Bits::AllOnes(32))),
m::Eq(m::Literal(Bits::AllOnes(32)), m::Param("x"))),
testing::AnyOf(m::Ne(m::Param("x"), m::Literal(Bits::AllOnes(32))),
m::Ne(m::Literal(Bits::AllOnes(32)), m::Param("x")))));
}

TEST_F(BasicSimplificationPassTest, AddWithZero) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down