Skip to content

Commit 78af912

Browse files
committed
Address code review comments
1 parent c2e3c51 commit 78af912

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,15 +1584,13 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
15841584
//===----------------------------------------------------------------------===//
15851585

15861586
OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1587-
mlir::Attribute lhs = adaptor.getLhs();
1588-
mlir::Attribute rhs = adaptor.getRhs();
1589-
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1590-
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1587+
auto lhsVecAttr =
1588+
mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs());
1589+
auto rhsVecAttr =
1590+
mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs());
1591+
if (!lhsVecAttr || !rhsVecAttr)
15911592
return {};
15921593

1593-
auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs);
1594-
auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs);
1595-
15961594
mlir::Type inputElemTy =
15971595
mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
15981596
if (!isAnyIntegerOrFloatingPointType(inputElemTy))
@@ -1603,17 +1601,15 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16031601
mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
16041602
uint64_t vecSize = lhsVecElhs.size();
16051603

1606-
auto resultVecTy = mlir::cast<cir::VectorType>(getType());
1607-
16081604
SmallVector<mlir::Attribute, 16> elements(vecSize);
1605+
bool isIntAttr = vecSize ? mlir::isa<cir::IntAttr>(lhsVecElhs[0]) : false;
16091606
for (uint64_t i = 0; i < vecSize; i++) {
16101607
mlir::Attribute lhsAttr = lhsVecElhs[i];
16111608
mlir::Attribute rhsAttr = rhsVecElhs[i];
1612-
16131609
int cmpResult = 0;
16141610
switch (opKind) {
16151611
case cir::CmpOpKind::lt: {
1616-
if (mlir::isa<cir::IntAttr>(lhsAttr)) {
1612+
if (isIntAttr) {
16171613
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
16181614
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
16191615
} else {
@@ -1623,7 +1619,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16231619
break;
16241620
}
16251621
case cir::CmpOpKind::le: {
1626-
if (mlir::isa<cir::IntAttr>(lhsAttr)) {
1622+
if (isIntAttr) {
16271623
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
16281624
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
16291625
} else {
@@ -1633,7 +1629,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16331629
break;
16341630
}
16351631
case cir::CmpOpKind::gt: {
1636-
if (mlir::isa<cir::IntAttr>(lhsAttr)) {
1632+
if (isIntAttr) {
16371633
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
16381634
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
16391635
} else {
@@ -1643,7 +1639,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16431639
break;
16441640
}
16451641
case cir::CmpOpKind::ge: {
1646-
if (mlir::isa<cir::IntAttr>(lhsAttr)) {
1642+
if (isIntAttr) {
16471643
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
16481644
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
16491645
} else {
@@ -1653,7 +1649,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16531649
break;
16541650
}
16551651
case cir::CmpOpKind::eq: {
1656-
if (mlir::isa<cir::IntAttr>(lhsAttr)) {
1652+
if (isIntAttr) {
16571653
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
16581654
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
16591655
} else {
@@ -1663,7 +1659,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16631659
break;
16641660
}
16651661
case cir::CmpOpKind::ne: {
1666-
if (mlir::isa<cir::IntAttr>(lhsAttr)) {
1662+
if (isIntAttr) {
16671663
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
16681664
mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
16691665
} else {
@@ -1674,7 +1670,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
16741670
}
16751671
}
16761672

1677-
elements[i] = cir::IntAttr::get(resultVecTy.getElementType(), cmpResult);
1673+
elements[i] = cir::IntAttr::get(getType().getElementType(), cmpResult);
16781674
}
16791675

16801676
return cir::ConstVectorAttr::get(

0 commit comments

Comments
 (0)