@@ -1584,15 +1584,13 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
1584
1584
// ===----------------------------------------------------------------------===//
1585
1585
1586
1586
OpFoldResult cir::VecCmpOp::fold (FoldAdaptor adaptor) {
1587
- mlir::Attribute lhs = adaptor.getLhs ();
1588
- mlir::Attribute rhs = adaptor.getRhs ();
1589
- if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1590
- !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1587
+ auto lhsVecAttr =
1588
+ mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getLhs ());
1589
+ auto rhsVecAttr =
1590
+ mlir::dyn_cast_if_present<cir::ConstVectorAttr>(adaptor.getRhs ());
1591
+ if (!lhsVecAttr || !rhsVecAttr)
1591
1592
return {};
1592
1593
1593
- auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs);
1594
- auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs);
1595
-
1596
1594
mlir::Type inputElemTy =
1597
1595
mlir::cast<cir::VectorType>(lhsVecAttr.getType ()).getElementType ();
1598
1596
if (!isAnyIntegerOrFloatingPointType (inputElemTy))
@@ -1603,17 +1601,15 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1603
1601
mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts ();
1604
1602
uint64_t vecSize = lhsVecElhs.size ();
1605
1603
1606
- auto resultVecTy = mlir::cast<cir::VectorType>(getType ());
1607
-
1608
1604
SmallVector<mlir::Attribute, 16 > elements (vecSize);
1605
+ bool isIntAttr = vecSize ? mlir::isa<cir::IntAttr>(lhsVecElhs[0 ]) : false ;
1609
1606
for (uint64_t i = 0 ; i < vecSize; i++) {
1610
1607
mlir::Attribute lhsAttr = lhsVecElhs[i];
1611
1608
mlir::Attribute rhsAttr = rhsVecElhs[i];
1612
-
1613
1609
int cmpResult = 0 ;
1614
1610
switch (opKind) {
1615
1611
case cir::CmpOpKind::lt: {
1616
- if (mlir::isa<cir::IntAttr>(lhsAttr) ) {
1612
+ if (isIntAttr ) {
1617
1613
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt () <
1618
1614
mlir::cast<cir::IntAttr>(rhsAttr).getSInt ();
1619
1615
} else {
@@ -1623,7 +1619,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1623
1619
break ;
1624
1620
}
1625
1621
case cir::CmpOpKind::le: {
1626
- if (mlir::isa<cir::IntAttr>(lhsAttr) ) {
1622
+ if (isIntAttr ) {
1627
1623
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt () <=
1628
1624
mlir::cast<cir::IntAttr>(rhsAttr).getSInt ();
1629
1625
} else {
@@ -1633,7 +1629,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1633
1629
break ;
1634
1630
}
1635
1631
case cir::CmpOpKind::gt: {
1636
- if (mlir::isa<cir::IntAttr>(lhsAttr) ) {
1632
+ if (isIntAttr ) {
1637
1633
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt () >
1638
1634
mlir::cast<cir::IntAttr>(rhsAttr).getSInt ();
1639
1635
} else {
@@ -1643,7 +1639,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1643
1639
break ;
1644
1640
}
1645
1641
case cir::CmpOpKind::ge: {
1646
- if (mlir::isa<cir::IntAttr>(lhsAttr) ) {
1642
+ if (isIntAttr ) {
1647
1643
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt () >=
1648
1644
mlir::cast<cir::IntAttr>(rhsAttr).getSInt ();
1649
1645
} else {
@@ -1653,7 +1649,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1653
1649
break ;
1654
1650
}
1655
1651
case cir::CmpOpKind::eq: {
1656
- if (mlir::isa<cir::IntAttr>(lhsAttr) ) {
1652
+ if (isIntAttr ) {
1657
1653
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt () ==
1658
1654
mlir::cast<cir::IntAttr>(rhsAttr).getSInt ();
1659
1655
} else {
@@ -1663,7 +1659,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1663
1659
break ;
1664
1660
}
1665
1661
case cir::CmpOpKind::ne: {
1666
- if (mlir::isa<cir::IntAttr>(lhsAttr) ) {
1662
+ if (isIntAttr ) {
1667
1663
cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt () !=
1668
1664
mlir::cast<cir::IntAttr>(rhsAttr).getSInt ();
1669
1665
} else {
@@ -1674,7 +1670,7 @@ OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
1674
1670
}
1675
1671
}
1676
1672
1677
- elements[i] = cir::IntAttr::get (resultVecTy .getElementType (), cmpResult);
1673
+ elements[i] = cir::IntAttr::get (getType () .getElementType (), cmpResult);
1678
1674
}
1679
1675
1680
1676
return cir::ConstVectorAttr::get (
0 commit comments