Skip to content

[DemandedBits] Support non-constant shift amounts #148880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions llvm/lib/Analysis/DemandedBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ void DemandedBits::determineLiveOperandBits(
computeKnownBits(V2, Known2, DL, &AC, UserI, &DT);
}
};
auto GetShiftedRange = [&](unsigned Min, unsigned Max, bool ShiftLeft) {
using ShiftFn = APInt (APInt::*)(unsigned) const;
auto Shift = ShiftLeft ? static_cast<ShiftFn>(&APInt::shl)
: static_cast<ShiftFn>(&APInt::lshr);
AB = APInt::getZero(BitWidth);
for (unsigned ShiftAmount = Min; ShiftAmount <= Max; ++ShiftAmount) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should convert it into an O(LlogL) algorithm with the Exponentiation by squaring trick. In the current implementation, we may need 128 shifts + 128 ors for i128.

Copy link
Author

@karouzakisp karouzakisp Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is L — the bit width (128 for i128)? Note that 128 + 128 = 256 < 128 * log2(128) = 896.

Also, I don't think `AB |= (AOut.*Shift)(ShiftAmount);` is composable the way multiplication is.

With multiplication we can do

x^(13) = x · x^(12) = x · (x⁶)² = x · ((x³)²)²

With shifts, however, (AOut << i) | (AOut << j) != AOut << (i + j)

for example
i = 1, j = 2
AOut = 0b 0000 0010
AOut << 1 = 0b 0000 0100
AOut << 2 = 0b 0000 1000

(AOut << 1 ) or (AOut << 2) = 0b 0000 1100 != AOut << (1+2)

AOut << 3 = 0b 0001 0000.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L is the length of integer types. A shift or an OR operation in APInt's slow path takes about k*(L/64) instructions. So the current implementation is O(L^2).

Copy link
Author

@karouzakisp karouzakisp Jul 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can keep intermediate results to avoid shifting one position at a time and OR-ing after each shift; here is the Alive2 proof for the lshr case --> https://alive2.llvm.org/ce/z/eeGzyB

This reduces the number of loop iterations to log(max - min + 1), resulting in O(L log(max - min + 1)) time complexity.

AB |= (AOut.*Shift)(ShiftAmount);
}
};

switch (UserI->getOpcode()) {
default: break;
Expand Down Expand Up @@ -183,6 +192,17 @@ void DemandedBits::determineLiveOperandBits(
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
else if (S->hasNoUnsignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
// similar to Lshr case
GetShiftedRange(Min, Max, /*ShiftLeft=*/false);
const auto *S = cast<ShlOperator>(UserI);
if (S->hasNoSignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, Max + 1);
else if (S->hasNoUnsignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, Max);
}
}
break;
Expand All @@ -197,6 +217,20 @@ void DemandedBits::determineLiveOperandBits(
// (they must be zero).
if (cast<LShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
// Suppose AOut == 0b0000 1001
// [min, max] = [1, 3]
// shift by 1 we get 0b0001 0010
// shift by 2 we get 0b0010 0100
// shift by 3 we get 0b0100 1000
// we take the or for every shift to cover all the positions.
//
GetShiftedRange(Min, Max, /*ShiftLeft=*/true);
if (cast<LShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, Max);
}
}
break;
Expand All @@ -217,6 +251,26 @@ void DemandedBits::determineLiveOperandBits(
// (they must be zero).
if (cast<AShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
} else {
ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
GetShiftedRange(Min, Max, /*ShiftLeft=*/true);
if (Max) {
// Suppose AOut = 0011 1100
// [min, max] = [1, 3]
// ShiftAmount = 1 : Mask is 1000 0000
// ShiftAmount = 2 : Mask is 1100 0000
// ShiftAmount = 3 : Mask is 1110 0000
// The Mask with Max covers every case in [min, max],
// so we are done
if ((AOut & APInt::getHighBitsSet(BitWidth, Max)).getBoolValue())
AB.setSignBit();
}
// If the shift is exact, then the low bits are not dead
// (they must be zero).
if (cast<AShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, Max);
}
}
break;
Expand Down
198 changes: 198 additions & 0 deletions llvm/test/Analysis/DemandedBits/ashr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s

; Constant shift amount 4 followed by a trunc to i8: only the low 8 bits of
; %ashr are demanded, so the bits demanded from %a are 0xff shifted left by 4
; (0xff0).
define i8 @test_ashr_const_amount_4(i32 %a) {
; CHECK-LABEL: 'test_ashr_const_amount_4'
; CHECK-DAG: DemandedBits: 0xff for %ashr = ashr i32 %a, 4
; CHECK-DAG: DemandedBits: 0xff0 for %a in %ashr = ashr i32 %a, 4
; CHECK-DAG: DemandedBits: 0xffffffff for 4 in %ashr = ashr i32 %a, 4
; CHECK-DAG: DemandedBits: 0xff for %ashr.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr in %ashr.t = trunc i32 %ashr to i8
;
%ashr = ashr i32 %a, 4
%ashr.t = trunc i32 %ashr to i8
ret i8 %ashr.t
}

; Constant shift amount 5 followed by a trunc to i8: the bits demanded from %a
; are 0xff shifted left by 5 (0x1fe0).
define i8 @test_ashr_const_amount_5(i32 %a) {
; CHECK-LABEL: 'test_ashr_const_amount_5'
; CHECK-DAG: DemandedBits: 0xff for %ashr = ashr i32 %a, 5
; CHECK-DAG: DemandedBits: 0x1fe0 for %a in %ashr = ashr i32 %a, 5
; CHECK-DAG: DemandedBits: 0xffffffff for 5 in %ashr = ashr i32 %a, 5
; CHECK-DAG: DemandedBits: 0xff for %ashr.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr in %ashr.t = trunc i32 %ashr to i8
;
%ashr = ashr i32 %a, 5
%ashr.t = trunc i32 %ashr to i8
ret i8 %ashr.t
}

; Constant shift amount 8 followed by a trunc to i8: the bits demanded from %a
; are 0xff shifted left by 8 (0xff00).
define i8 @test_ashr_const_amount_8(i32 %a) {
; CHECK-LABEL: 'test_ashr_const_amount_8'
; CHECK-DAG: DemandedBits: 0xff for %ashr = ashr i32 %a, 8
; CHECK-DAG: DemandedBits: 0xff00 for %a in %ashr = ashr i32 %a, 8
; CHECK-DAG: DemandedBits: 0xffffffff for 8 in %ashr = ashr i32 %a, 8
; CHECK-DAG: DemandedBits: 0xff for %ashr.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr in %ashr.t = trunc i32 %ashr to i8
;
%ashr = ashr i32 %a, 8
%ashr.t = trunc i32 %ashr to i8
ret i8 %ashr.t
}

; Constant shift amount 9 followed by a trunc to i8: the bits demanded from %a
; are 0xff shifted left by 9 (0x1fe00). The shift amount now matches the test
; name; it previously used 8 and duplicated test_ashr_const_amount_8.
define i8 @test_ashr_const_amount_9(i32 %a) {
; CHECK-LABEL: 'test_ashr_const_amount_9'
; CHECK-DAG: DemandedBits: 0xff for %ashr.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr in %ashr.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr = ashr i32 %a, 9
; CHECK-DAG: DemandedBits: 0x1fe00 for %a in %ashr = ashr i32 %a, 9
; CHECK-DAG: DemandedBits: 0xffffffff for 9 in %ashr = ashr i32 %a, 9
;
%ashr = ashr i32 %a, 9
%ashr.t = trunc i32 %ashr to i8
ret i8 %ashr.t
}

; Shift amount %b is an unconstrained function argument: although only the low
; 8 bits of %ashr are demanded, every bit of %a is expected to be demanded.
define i8 @test_ashr(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr'
; CHECK-DAG: DemandedBits: 0xff for %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %a in %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xff for %ashr.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr in %ashr.t = trunc i32 %ashr to i8
;
%ashr = ashr i32 %a, %b
%ashr.t = trunc i32 %ashr to i8
ret i8 %ashr.t
}

; Shift amount is masked with 3, so it lies in 0..3. With only the low 8 bits
; of the result demanded, the bits demanded from %a are the union of 0xff
; shifted left by 0, 1, 2 and 3, i.e. 0x7ff.
; Note: the trunc result is named %shl.t even though it truncates an ashr.
define i8 @test_ashr_range_1(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr_range_1'
; CHECK-DAG: DemandedBits: 0xff for %shl.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xff for %ashr in %shl.t = trunc i32 %ashr to i8
; CHECK-DAG: DemandedBits: 0xffffffff for %b2 = and i32 %b, 3
; CHECK-DAG: DemandedBits: 0x3 for %b in %b2 = and i32 %b, 3
; CHECK-DAG: DemandedBits: 0xffffffff for 3 in %b2 = and i32 %b, 3
; CHECK-DAG: DemandedBits: 0xff for %ashr = ashr i32 %a, %b2
; CHECK-DAG: DemandedBits: 0x7ff for %a in %ashr = ashr i32 %a, %b2
; CHECK-DAG: DemandedBits: 0xffffffff for %b2 in %ashr = ashr i32 %a, %b2
;
%b2 = and i32 %b, 3
%ashr = ashr i32 %a, %b2
%shl.t = trunc i32 %ashr to i8
ret i8 %shl.t
}

; Shift amount is masked with 3, but the full i32 result is returned, so all
; bits of %ashr and therefore all bits of %a are expected to be demanded.
define i32 @test_ashr_range_2(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr_range_2'
; CHECK-DAG: DemandedBits: 0xffffffff for %b2 = and i32 %b, 3
; CHECK-DAG: DemandedBits: 0x3 for %b in %b2 = and i32 %b, 3
; CHECK-DAG: DemandedBits: 0xffffffff for 3 in %b2 = and i32 %b, 3
; CHECK-DAG: DemandedBits: 0xffffffff for %ashr = ashr i32 %a, %b2
; CHECK-DAG: DemandedBits: 0xffffffff for %a in %ashr = ashr i32 %a, %b2
; CHECK-DAG: DemandedBits: 0xffffffff for %b2 in %ashr = ashr i32 %a, %b2
;
%b2 = and i32 %b, 3
%ashr = ashr i32 %a, %b2
ret i32 %ashr
}

; ashr by an unconstrained amount feeding a shl by 16: only the low 16 bits of
; %ashr are demanded (0xffff), yet every bit of %a is expected to be demanded.
define i32 @test_ashr_range_3(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr_range_3'
; CHECK-DAG: DemandedBits: 0xffff for %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %a in %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %shl = shl i32 %ashr, 16
; CHECK-DAG: DemandedBits: 0xffff for %ashr in %shl = shl i32 %ashr, 16
; CHECK-DAG: DemandedBits: 0xffffffff for 16 in %shl = shl i32 %ashr, 16
;
%ashr = ashr i32 %a, %b
%shl = shl i32 %ashr, 16
ret i32 %shl
}
; ashr by an unconstrained amount feeding an lshr by 8: the low 8 bits of %ashr
; are shifted out, so 0xffffff00 is demanded of %ashr, and the same mask is
; expected for %a.
define i32 @test_ashr_range_4(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr_range_4'
; CHECK-DAG: DemandedBits: 0xffffffff for %shr = lshr i32 %ashr, 8
; CHECK-DAG: DemandedBits: 0xffffff00 for %ashr in %shr = lshr i32 %ashr, 8
; CHECK-DAG: DemandedBits: 0xffffffff for 8 in %shr = lshr i32 %ashr, 8
; CHECK-DAG: DemandedBits: 0xffffff00 for %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffff00 for %a in %ashr = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %ashr = ashr i32 %a, %b
%ashr = ashr i32 %a, %b
%shr = lshr i32 %ashr, 8
ret i32 %shr
}

; ashr by an unconstrained amount feeding an and with 255: only the low 8 bits
; of the shift result are demanded, while every bit of %a is expected to be
; demanded.
define i32 @test_ashr_range_5(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr_range_5'
; CHECK-DAG: DemandedBits: 0xffffffff for %2 = and i32 %1, 255
; CHECK-DAG: DemandedBits: 0xff for %1 in %2 = and i32 %1, 255
; CHECK-DAG: DemandedBits: 0xffffffff for 255 in %2 = and i32 %1, 255
; CHECK-DAG: DemandedBits: 0xff for %1 = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %a in %1 = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %1 = ashr i32 %a, %b
;
%1 = ashr i32 %a, %b
%2 = and i32 %1, 255
ret i32 %2
}

; ashr by an unconstrained amount feeding an ashr by constant 16: the upper 16
; bits (0xffff0000) are demanded of the first shift's result, and the same mask
; is expected for %a.
; Note: the values are named %lshr.* although both instructions are ashr.
define i32 @test_ashr_range_6(i32 %a, i32 %b) {
; CHECK-LABEL: 'test_ashr_range_6'
; CHECK-DAG: DemandedBits: 0xffff0000 for %lshr.1 = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffff0000 for %a in %lshr.1 = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %lshr.1 = ashr i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %lshr.2 = ashr i32 %lshr.1, 16
; CHECK-DAG: DemandedBits: 0xffff0000 for %lshr.1 in %lshr.2 = ashr i32 %lshr.1, 16
; CHECK-DAG: DemandedBits: 0xffffffff for 16 in %lshr.2 = ashr i32 %lshr.1, 16
;
%lshr.1 = ashr i32 %a, %b
%lshr.2 = ashr i32 %lshr.1, 16
ret i32 %lshr.2
}

; The shift amount %3 is derived from the shifted value itself (%1 truncated to
; i8 and zero-extended back to i32). Only the low 8 bits of the shift result
; are demanded, while every bit of %1 is expected to be demanded.
define i8 @test_ashr_var_amount(i32 %a, i32 %b){
; CHECK-LABEL: 'test_ashr_var_amount'
; CHECK-DAG: DemandedBits: 0xff for %4 = ashr i32 %1, %3
; CHECK-DAG: DemandedBits: 0xffffffff for %1 in %4 = ashr i32 %1, %3
; CHECK-DAG: DemandedBits: 0xffffffff for %3 in %4 = ashr i32 %1, %3
; CHECK-DAG: DemandedBits: 0xff for %2 = trunc i32 %1 to i8
; CHECK-DAG: DemandedBits: 0xff for %1 in %2 = trunc i32 %1 to i8
; CHECK-DAG: DemandedBits: 0xffffffff for %1 = add nsw i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %a in %1 = add nsw i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %1 = add nsw i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %3 = zext i8 %2 to i32
; CHECK-DAG: DemandedBits: 0xff for %2 in %3 = zext i8 %2 to i32
; CHECK-DAG: DemandedBits: 0xff for %5 = trunc i32 %4 to i8
; CHECK-DAG: DemandedBits: 0xff for %4 in %5 = trunc i32 %4 to i8
;
%1 = add nsw i32 %a, %b
%2 = trunc i32 %1 to i8
%3 = zext i8 %2 to i32
%4 = ashr i32 %1, %3
%5 = trunc i32 %4 to i8
ret i8 %5
}

; Same shape as test_ashr_var_amount, but the shift carries the `exact` flag
; (the "nsw" in the test name refers to the feeding add). Every bit of %1 is
; expected to be demanded.
; The label directive below previously lacked its trailing colon, so it was not
; recognised as a directive and the expectations were not anchored to this
; function's output.
define i8 @test_ashr_var_amount_nsw(i32 %a, i32 %b){
; CHECK-LABEL: 'test_ashr_var_amount_nsw'
; CHECK-DAG: DemandedBits: 0xff for %5 = trunc i32 %4 to i8
; CHECK-DAG: DemandedBits: 0xff for %4 in %5 = trunc i32 %4 to i8
; CHECK-DAG: DemandedBits: 0xffffffff for %1 = add nsw i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %a in %1 = add nsw i32 %a, %b
; CHECK-DAG: DemandedBits: 0xffffffff for %b in %1 = add nsw i32 %a, %b
; CHECK-DAG: DemandedBits: 0xff for %2 = trunc i32 %1 to i8
; CHECK-DAG: DemandedBits: 0xff for %1 in %2 = trunc i32 %1 to i8
; CHECK-DAG: DemandedBits: 0xffffffff for %3 = zext i8 %2 to i32
; CHECK-DAG: DemandedBits: 0xff for %2 in %3 = zext i8 %2 to i32
; CHECK-DAG: DemandedBits: 0xff for %4 = ashr exact i32 %1, %3
; CHECK-DAG: DemandedBits: 0xffffffff for %1 in %4 = ashr exact i32 %1, %3
; CHECK-DAG: DemandedBits: 0xffffffff for %3 in %4 = ashr exact i32 %1, %3
;
%1 = add nsw i32 %a, %b
%2 = trunc i32 %1 to i8
%3 = zext i8 %2 to i32
%4 = ashr exact i32 %1, %3
%5 = trunc i32 %4 to i8
ret i8 %5
}
Loading
Loading