Skip to content

[InstCombine] Offset both sides of an equality icmp #134086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
@@ -5808,6 +5808,128 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
return nullptr;
}

/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
bool AllowRecursion) {
Instruction *Inst = dyn_cast<Instruction>(V);
if (!Inst || !Inst->hasOneUse())
return;

switch (Inst->getOpcode()) {
case Instruction::Add:
Offsets.emplace_back(Instruction::Sub, Inst->getOperand(1));
Offsets.emplace_back(Instruction::Sub, Inst->getOperand(0));
break;
case Instruction::Sub:
Offsets.emplace_back(Instruction::Add, Inst->getOperand(1));
break;
case Instruction::Xor:
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
break;
case Instruction::Select:
if (AllowRecursion) {
collectOffsetOp(Inst->getOperand(1), Offsets, /*AllowRecursion=*/false);
collectOffsetOp(Inst->getOperand(2), Offsets, /*AllowRecursion=*/false);
}
break;
default:
break;
}
}

enum class OffsetKind { Invalid, Value, Select };

struct OffsetResult {
OffsetKind Kind;
Value *V0, *V1, *V2;

static OffsetResult invalid() {
return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
}
static OffsetResult value(Value *V) {
return {OffsetKind::Value, V, nullptr, nullptr};
}
static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
return {OffsetKind::Select, Cond, TrueV, FalseV};
}
bool isValid() const { return Kind != OffsetKind::Invalid; }
Value *materialize(InstCombiner::BuilderTy &Builder) const {
switch (Kind) {
case OffsetKind::Invalid:
llvm_unreachable("Invalid offset result");
case OffsetKind::Value:
return V0;
case OffsetKind::Select:
return Builder.CreateSelect(V0, V1, V2);
}
}
};

/// Offset both sides of an equality icmp to see if we can save some
/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
/// Note: This operation should not introduce poison.
static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &SQ) {
assert(I.isEquality() && "Expected an equality icmp");
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (!Op0->getType()->isIntOrIntVectorTy())
return nullptr;

SmallVector<OffsetOp, 4> OffsetOps;
collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);

auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
// Avoid infinite loops by checking if RHS is an identity for the BinOp.
if (!Simplified || Simplified == V)
return nullptr;
// Reject constant expressions as they don't simplify things.
if (isa<Constant>(Simplified) && !match(Simplified, m_ImmConstant()))
return nullptr;
// Check if the transformation introduces poison.
return impliesPoison(RHS, V) ? Simplified : nullptr;
};

auto ApplyOffset = [&](Value *V, unsigned BinOpc,
Value *RHS) -> OffsetResult {
if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (!Sel->hasOneUse())
return OffsetResult::invalid();
Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
if (!TrueVal)
return OffsetResult::invalid();
Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
if (!FalseVal)
return OffsetResult::invalid();
return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
}
if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
return OffsetResult::value(Simplified);
return OffsetResult::invalid();
};

for (auto [BinOp, RHS] : OffsetOps) {
auto BinOpc = static_cast<unsigned>(BinOp);

auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
if (!Op0Result.isValid())
continue;
auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
if (!Op1Result.isValid())
continue;

Value *NewLHS = Op0Result.materialize(Builder);
Value *NewRHS = Op1Result.materialize(Builder);
return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
}

return nullptr;
}

Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
if (!I.isEquality())
return nullptr;
@@ -6054,6 +6176,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
: ConstantInt::getNullValue(A->getType()));
}

if (auto *Res = foldICmpEqualityWithOffset(
I, Builder, getSimplifyQuery().getWithInstruction(&I)))
return Res;

return nullptr;
}

6 changes: 2 additions & 4 deletions llvm/test/Transforms/InstCombine/icmp-add.ll
Original file line number Diff line number Diff line change
@@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {

define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
; CHECK-LABEL: @icmp_eq_add_undef2(
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%add = add <2 x i32> %a, <i32 5, i32 5>
@@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {

define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
; CHECK-LABEL: @icmp_eq_add_non_splat2(
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%add = add <2 x i32> %a, <i32 5, i32 5>
3 changes: 1 addition & 2 deletions llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
Original file line number Diff line number Diff line change
@@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
define <2 x i1> @foo3(<2 x i8> %x) {
; CHECK-LABEL: @foo3(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
entry:
236 changes: 236 additions & 0 deletions llvm/test/Transforms/InstCombine/icmp-select.ll
Original file line number Diff line number Diff line change
@@ -628,3 +628,239 @@ define i1 @icmp_slt_select(i1 %cond, i32 %a, i32 %b) {
%res = icmp slt i32 %lhs, %rhs
ret i1 %res
}

define i1 @discr_eq(i8 %a, i8 %b) {
; CHECK-LABEL: @discr_eq(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, -2
%cmp1 = icmp ugt i8 %a, 1
%sel1 = select i1 %cmp1, i8 %add1, i8 1
%add2 = add i8 %b, -2
%cmp2 = icmp ugt i8 %b, 1
%sel2 = select i1 %cmp2, i8 %add2, i8 1
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

define i1 @discr_ne(i8 %a, i8 %b) {
; CHECK-LABEL: @discr_ne(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 3
; CHECK-NEXT: [[RES:%.*]] = icmp ne i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, -2
%cmp1 = icmp ugt i8 %a, 1
%sel1 = select i1 %cmp1, i8 %add1, i8 1
%add2 = add i8 %b, -2
%cmp2 = icmp ugt i8 %b, 1
%sel2 = select i1 %cmp2, i8 %add2, i8 1
%res = icmp ne i8 %sel1, %sel2
ret i1 %res
}

define i1 @discr_xor_eq(i8 %a, i8 %b) {
; CHECK-LABEL: @discr_xor_eq(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B:%.*]], 1
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 -4
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP2]], i8 [[B]], i8 -4
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%xor1 = xor i8 %a, -3
%cmp1 = icmp ugt i8 %a, 1
%sel1 = select i1 %cmp1, i8 %xor1, i8 1
%xor2 = xor i8 %b, -3
%cmp2 = icmp ugt i8 %b, 1
%sel2 = select i1 %cmp2, i8 %xor2, i8 1
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

define i1 @discr_eq_simple(i8 %a, i8 %b) {
; CHECK-LABEL: @discr_eq_simple(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], 1
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 3
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[B:%.*]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, -2
%cmp1 = icmp ugt i8 %a, 1
%sel1 = select i1 %cmp1, i8 %add1, i8 1
%add2 = add i8 %b, -2
%res = icmp eq i8 %sel1, %add2
ret i1 %res
}

define i1 @discr_eq_add_commuted(i8 noundef %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) {
; CHECK-LABEL: @discr_eq_add_commuted(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[COND1:%.*]], i8 [[B:%.*]], i8 0
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND2:%.*]], i8 [[C:%.*]], i8 [[B]]
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, %b
%sel1 = select i1 %cond1, i8 %add1, i8 %a
%add2 = add i8 %c, %a
%sel2 = select i1 %cond2, i8 %add2, i8 %add1
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

define i1 @discr_eq_add_commuted_implies_poison(i8 %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) {
; CHECK-LABEL: @discr_eq_add_commuted_implies_poison(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[COND1:%.*]], i8 [[B:%.*]], i8 0
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND2:%.*]], i8 [[C:%.*]], i8 [[B]]
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, %b
%sel1 = select i1 %cond1, i8 %add1, i8 %a
%add2 = add i8 %c, %a
%sel2 = select i1 %cond2, i8 %add2, i8 %add1
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

define i1 @discr_eq_sub(i8 noundef %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) {
; CHECK-LABEL: @discr_eq_sub(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = select i1 [[COND1:%.*]], i8 [[B:%.*]], i8 0
; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND2:%.*]], i8 [[C:%.*]], i8 0
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%neg = sub i8 0, %a
%sub1 = sub i8 %b, %a
%sel1 = select i1 %cond1, i8 %sub1, i8 %neg
%sub2 = sub i8 %c, %a
%sel2 = select i1 %cond2, i8 %sub2, i8 %neg
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

; Negative tests

define i1 @discr_eq_multi_use(i8 %a, i8 %b) {
; CHECK-LABEL: @discr_eq_multi_use(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -2
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
; CHECK-NEXT: call void @use(i8 [[SEL1]])
; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, -2
%cmp1 = icmp ugt i8 %a, 1
%sel1 = select i1 %cmp1, i8 %add1, i8 1
call void @use(i8 %sel1)
%add2 = add i8 %b, -2
%cmp2 = icmp ugt i8 %b, 1
%sel2 = select i1 %cmp2, i8 %add2, i8 1
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

define i1 @discr_eq_failed_to_simplify(i8 %a, i8 %b) {
; CHECK-LABEL: @discr_eq_failed_to_simplify(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[A:%.*]], -3
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A]], 1
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 [[ADD1]], i8 1
; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[B:%.*]], -2
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i8 [[B]], 1
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[ADD2]], i8 1
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%add1 = add i8 %a, -3
%cmp1 = icmp ugt i8 %a, 1
%sel1 = select i1 %cmp1, i8 %add1, i8 1
%add2 = add i8 %b, -2
%cmp2 = icmp ugt i8 %b, 1
%sel2 = select i1 %cmp2, i8 %add2, i8 1
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

define <2 x i1> @discr_eq_simple_vec(<2 x i8> %a, <2 x i8> %b, i1 %cond) {
; CHECK-LABEL: @discr_eq_simple_vec(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[ADD1:%.*]] = add <2 x i8> [[A:%.*]], <i8 poison, i8 -2>
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND:%.*]], <2 x i8> [[ADD1]], <2 x i8> splat (i8 1)
; CHECK-NEXT: [[ADD2:%.*]] = add <2 x i8> [[B:%.*]], <i8 -2, i8 poison>
; CHECK-NEXT: [[RES:%.*]] = icmp eq <2 x i8> [[SEL1]], [[ADD2]]
; CHECK-NEXT: ret <2 x i1> [[RES]]
;
entry:
%add1 = add <2 x i8> %a, <i8 poison, i8 -2>
%sel1 = select i1 %cond, <2 x i8> %add1, <2 x i8> splat(i8 1)
%add2 = add <2 x i8> %b, <i8 -2, i8 poison>
%res = icmp eq <2 x i8> %sel1, %add2
ret <2 x i1> %res
}

define i1 @discr_eq_sub_commuted(i8 noundef %a, i8 %b, i8 %c, i1 %cond1, i1 %cond2) {
; CHECK-LABEL: @discr_eq_sub_commuted(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]]
; CHECK-NEXT: [[SUB1:%.*]] = sub i8 [[A]], [[B:%.*]]
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i8 [[SUB1]], i8 [[NEG]]
; CHECK-NEXT: [[SUB2:%.*]] = sub i8 [[A]], [[C:%.*]]
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND2:%.*]], i8 [[SUB2]], i8 [[NEG]]
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[SEL1]], [[SEL2]]
; CHECK-NEXT: ret i1 [[RES]]
;
entry:
%neg = sub i8 0, %a
%sub1 = sub i8 %a, %b
%sel1 = select i1 %cond1, i8 %sub1, i8 %neg
%sub2 = sub i8 %a, %c
%sel2 = select i1 %cond2, i8 %sub2, i8 %neg
%res = icmp eq i8 %sel1, %sel2
ret i1 %res
}

@g = external global i8

; Do not introduce constant expressions.
define i1 @discr_eq_constantexpr(ptr %p) {
; CHECK-LABEL: @discr_eq_constantexpr(
; CHECK-NEXT: [[I:%.*]] = ptrtoint ptr [[P:%.*]] to i64
; CHECK-NEXT: [[SUB:%.*]] = sub i64 [[I]], ptrtoint (ptr @g to i64)
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[SUB]], -1
; CHECK-NEXT: ret i1 [[CMP]]
;
%i = ptrtoint ptr %p to i64
%sub = sub i64 %i, ptrtoint (ptr @g to i64)
%cmp = icmp eq i64 %sub, -1
ret i1 %cmp
}