@@ -1616,6 +1616,223 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern<OpTy> {
16161616};
16171617} // namespace
16181618
1619+ namespace {
1620+ template <typename OpTy, typename PoolingOpTy, int Dim>
1621+ class ConvertRoiAlignOp : public OpConversionPattern <OpTy> {
1622+ public:
1623+ using OpConversionPattern<OpTy>::OpConversionPattern;
1624+ LogicalResult
1625+ matchAndRewrite (OpTy op, typename OpTy::Adaptor adaptor,
1626+ ConversionPatternRewriter &rewriter) const override {
1627+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
1628+ return failure ();
1629+
1630+ Location loc = op->getLoc ();
1631+ const TypeConverter *typeConverter = this ->getTypeConverter ();
1632+ Value result = op.getResult ();
1633+
1634+ uint64_t pooledHeight =
1635+ cast<ConstantIntOp>(op.getPooledHeight ().getDefiningOp ()).getValue ();
1636+ uint64_t pooledWidth =
1637+ cast<ConstantIntOp>(op.getPooledWidth ().getDefiningOp ()).getValue ();
1638+ uint64_t samplingRatio =
1639+ cast<ConstantIntOp>(op.getSamplingRatio ().getDefiningOp ()).getValue ();
1640+ Value pooledH = op.getPooledHeight ();
1641+ Value pooledW = op.getPooledWidth ();
1642+ Value spatialScaleVal = op.getSpatialScale ();
1643+ llvm::APFloat spatialScale =
1644+ cast<ConstantFloatOp>(op.getSpatialScale ().getDefiningOp ()).getValue ();
1645+ Value rois = op.getRois ();
1646+ Value input = op.getInput ();
1647+ // RankedTensorType inputType = input.getType();
1648+ Value offset =
1649+ rewriter.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
1650+ Type resultType = cast<RankedTensorType>(result.getType ());
1651+ Type resultElementType = resultType.getElementType ();
1652+ if (!op.getAligned ()) {
1653+ offset = rewriter.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.5 ));
1654+ }
1655+
1656+ Value lb = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1657+ Value ub0 = rewriter.create <tensor::DimOp>(loc, rois, 0 );
1658+ Value ub1 = rewriter.create <tensor::DimOp>(loc, input, 1 );
1659+ Value step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
1660+ SmallVector<Value> finalOutputShape = {ub0, ub1, pooledH, pooledW};
1661+ Value finalOutputTensor = rewriter.create <tensor::EmptyOp>(
1662+ loc, getAsOpFoldResult (finalOutputShape), resultElementType);
1663+ auto forLoop = rewriter.create <scf::ForOp>(
1664+ loc, lb, ub0, step, ValueRange{},
1665+ [&](OpBuilder &b1, Location loc, Value iv0, ValueRange args) {
1666+ auto forLoop = b1.create <scf::ForOp>(
1667+ loc, lb, ub1, step, ValueRange{},
1668+ [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) {
1669+ // Step 1: Extract bounds for region of interest (roi)
1670+ OpFoldResult zeroAttr = b.getI64IntegerAttr (0 );
1671+ OpFoldResult oneAttr = b.getI64IntegerAttr (1 );
1672+ OpFoldResult twoAttr = b.getI64IntegerAttr (2 );
1673+ OpFoldResult threeAttr = b.getI64IntegerAttr (3 );
1674+ OpFoldResult fourAttr = b.getI64IntegerAttr (4 );
1675+ OpFoldResult fiveAttr = b.getI64IntegerAttr (5 );
1676+ // SmallVector<Value> offsetVals{iv0, zeroAttr};
1677+ // SmallVector<OpFoldResult> sizeVals{oneAttr, fiveAttr};
1678+ SmallVector<OpFoldResult> strideVals{oneAttr, oneAttr, oneAttr,
1679+ oneAttr};
1680+ // Value extractRoiBounds = b.create<tensor::ExtractSliceOp>(
1681+ // loc, rois, offsetVals, sizeVals, strideVals);
1682+ Value lowY = b.create <tensor::ExtractOp>(
1683+ loc, rois, ValueRange{iv0, oneAttr});
1684+ Value lowX = b.create <tensor::ExtractOp>(
1685+ loc, rois, ValueRange{iv0, twoAttr});
1686+ Value highY = b.create <tensor::ExtractOp>(
1687+ loc, rois, ValueRange{iv0, threeAttr});
1688+ Value highX = b.create <tensor::ExtractOp>(
1689+ loc, rois, ValueRange{iv0, fourAttr});
1690+
1691+ lowY = b.create <arith::MulFOp>(loc, lowY, spatialScaleVal);
1692+ lowX = b.create <arith::MulFOp>(loc, lowX, spatialScaleVal);
1693+ highY = b.create <arith::MulFOp>(loc, highY, spatialScaleVal);
1694+ highX = b.create <arith::MulFOp>(loc, highX, spatialScaleVal);
1695+
1696+ lowY = b.create <arith::SubFOp>(loc, lowY, offset);
1697+ lowX = b.create <arith::SubFOp>(loc, lowX, offset);
1698+ highY = b.create <arith::SubFOp>(loc, highY, offset);
1699+ highX = b.create <arith::SubFOp>(loc, highX, offset);
1700+
1701+ // Step 2: Extract region of interest using bounds
1702+ Value lowY_int = b.create <math::FloorOp>(loc, lowY);
1703+ Value lowX_int = b.create <math::FloorOp>(loc, lowX);
1704+ Value highY_int = b.create <math::CeilOp>(loc, highY);
1705+ Value highX_int = b.create <math::CeilOp>(loc, highX);
1706+ lowY_int =
1707+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), lowY_int);
1708+ lowX_int =
1709+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), lowX_int);
1710+ highY_int =
1711+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), highY_int);
1712+ highX_int =
1713+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), highX_int);
1714+
1715+ Value roiHeight =
1716+ b.create <arith::SubIOp>(loc, highY_int, lowY_int);
1717+ Value roiWidth =
1718+ b.create <arith::SubIOp>(loc, highX_int, lowX_int);
1719+
1720+ SmallVector<Value> roiOffsetVals{zeroAttr, iv1, lowY_int,
1721+ lowX_int};
1722+ SmallVector<Value> roiSizeVals{oneAttr, oneAttr, roiHeight,
1723+ roiWidth};
1724+
1725+ Value extractRoi = b.create <tensor::ExtractSliceOp>(
1726+ loc, input, roiOffsetVals, roiSizeVals, strideVals);
1727+
1728+ // Step 3: Perform bilinear interpolation over roi
1729+ Value roiBinH = b.create <arith::SubOp>(loc, highY, lowY);
1730+ Value roiBinW = b.create <arith::SubOp>(loc, highX, lowX);
1731+ Value scaleH = b.create <arith::DivOp>(loc, roiBinH, pooledH);
1732+ Value scaleW = b.create <arith::DivOp>(loc, roiBinW, pooledW);
1733+ scaleH = b.create <arith::CeilOp>(loc, scaleH);
1734+ scaleW = b.create <arith::CeilOp>(loc, scaleW);
1735+ scaleH = b.create <arith::FPToSIOp>(loc, b.getI64Type (), scaleH);
1736+ scaleW = b.create <arith::FPToSIOp>(loc, b.getI64Type (), scaleW);
1737+
1738+ Value roiSampleHeight =
1739+ b.create <arith::MulIOp>(loc, pooledH, scaleH);
1740+ Value roiSampleWidth =
1741+ b.create <arith::MulIOp>(loc, pooledW, scaleW);
1742+
1743+ SmallVector<Value> outputSizeIntValues = {roiSampleHeight,
1744+ roiSampleWidth};
1745+ SmallVector<Value> dims =
1746+ getTensorSizesUntilDim (b, loc, extractRoi, 1 );
1747+ for (unsigned i = 2 ; i < inputRank; i++) {
1748+ dims.push_back (
1749+ castIntToIndex (b, loc, outputSizeIntValues[i - 2 ]));
1750+ }
1751+ SmallVector<Value> inputSizes;
1752+ auto inputType = cast<RankedTensorType>(extractRoi.getType ());
1753+ auto inputRank = inputType.getRank ();
1754+ for (unsigned i = 2 ; i < inputRank; i++) {
1755+ Value inputSize = getDimOp (b, loc, extractRoi, i);
1756+ inputSizes.push_back (b.create <arith::IndexCastOp>(
1757+ loc, b.getIntegerType (64 ), roiSizeVals[i]));
1758+ }
1759+ Value outTensor = b.create <tensor::EmptyOp>(
1760+ loc, getAsOpFoldResult (dims), inputType.getElementType ());
1761+ AffineMap idMap = b.getMultiDimIdentityMap (inputRank);
1762+ SmallVector<utils::IteratorType> iteratorTypes (
1763+ inputRank, utils::IteratorType::parallel);
1764+ Value bilinearInterpolatedRoi =
1765+ b.create <linalg::GenericOp>(
1766+ loc, outTensor.getType (), ValueRange{}, outTensor,
1767+ /* indexingMaps=*/ idMap,
1768+ /* iteratorTypes=*/ iteratorTypes,
1769+ [&](OpBuilder &b, Location loc, ValueRange args) {
1770+ Value retVal = bilinearInterpolate (
1771+ b, op, loc, outputSizeIntValues, extractRoi,
1772+ inputSizes, ValueRange{}, " bilinear" );
1773+ b.create <linalg::YieldOp>(loc, retVal);
1774+ })
1775+ .getResult (0 );
1776+
1777+ // Step 4: Sum pool over interpolated values
1778+ Value sumPool, paddedInput;
1779+ SmallVector<Value> kernelSizeIntValues = {oneAttr, oneAttr,
1780+ scaleH, scaleW};
1781+ SmallVector<Value, 2 > strideInts = {scaleH, scaleW};
1782+ SmallVector<Value, 2 > paddingInts = {zeroAttr, zeroAttr};
1783+ SmallVector<Value, 2 > dilationInts (oneAttr, 2 );
1784+ SmallVector<Value, 4 > outTensorShape;
1785+ if (failed (createPoolingOp<linalg::PoolingNchwSumOp>(
1786+ op, b, self, /* supportNonFPInput=*/ true , false ,
1787+ /* dimensionality=*/ 2 , kernelSizeIntValues, strideInts,
1788+ paddingInts, dilationInts,
1789+ b.getZeroAttr (resultElementType), outTensorShape,
1790+ paddedInput, sumPool)))
1791+ return b.notifyMatchFailure (op, " unable to compute sumpool" );
1792+
1793+ // Step 5: elementwise division by number of sampling points
1794+ // to compute avg pool
1795+ Value outputTensor = b.create <tensor::EmptyOp>(
1796+ loc, getAsOpFoldResult (outTensorShape), resultElementType);
1797+ Value divisor = b.create <arith::MulIOp>(loc, scaleH, scaleW);
1798+ Value avgPool =
1799+ b.create <linalg::GenericOp>(
1800+ loc, outputTensor.getType (), sumPool, outputTensor,
1801+ /* indexingMaps=*/ indexingMapsAvg,
1802+ /* iteratorTypes=*/ iteratorTypesAvg,
1803+ [&](OpBuilder &b, Location loc, ValueRange args) {
1804+ Value avg;
1805+ if (isa<mlir::IntegerType>(resultElementType))
1806+ avg = b.create <arith::DivSIOp>(loc, args[0 ],
1807+ divisor);
1808+ else if (isa<mlir::FloatType>(resultElementType))
1809+ avg =
1810+ b.create <arith::DivFOp>(loc, args[0 ], divisor);
1811+ b.create <linalg::YieldOp>(loc, avg);
1812+ })
1813+ .getResult (0 );
1814+
1815+ SmallVector<OpFoldResult> finalStrides (inputRank, oneAttr);
1816+ SmallVector<OpFoldResult> finalOffsets = {
1817+ getAsOpFoldResult (iv0), getAsOpFoldResult (iv1), zeroAttr,
1818+ zeroAttr};
1819+ SmallVector<OpFoldResult> finalSizes = {
1820+ oneAttr, oneAttr, getAsOpFoldResult (pooledH),
1821+ getAsOpFoldResult (pooledW)};
1822+ SmallVector<OpFoldResult> diagStrides (inputRank, oneAttr);
1823+ finalOutputTensor = b.create <tensor::InsertSliceOp>(
1824+ loc, finalOutputTensor, avgPool, finalOffsets, finalSizes,
1825+ finalStrides);
1826+ });
1827+ });
1828+
1829+ Type resultType = typeConverter->convertType (op.getType ());
1830+ b.replaceOp (op, finalOutputTensor);
1831+ return success ();
1832+ }
1833+ };
1834+ } // namespace
1835+
16191836void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality (
16201837 TypeConverter &typeConverter, RewritePatternSet &patterns,
16211838 ConversionTarget &target) {
0 commit comments