@@ -1673,80 +1673,6 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
16731673
16741674} // namespace
16751675
1676- namespace {
1677- class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1678- public:
1679- using OpConversionPattern::OpConversionPattern;
1680- LogicalResult
1681- matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1682- ConversionPatternRewriter &rewriter) const override {
1683-
1684- Location loc = op->getLoc ();
1685- Value lhs = adaptor.getSelf ();
1686- Value rhs = adaptor.getVec2 ();
1687-
1688- if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1689- return failure ();
1690- }
1691- auto lhsType = dyn_cast<RankedTensorType>(lhs.getType ());
1692- auto rhsType = dyn_cast<RankedTensorType>(rhs.getType ());
1693-
1694- if (!lhsType || !rhsType)
1695- return rewriter.notifyMatchFailure (op,
1696- " outer: expected ranked tensor types" );
1697- if (lhsType.getRank () != 1 || rhsType.getRank () != 1 )
1698- return rewriter.notifyMatchFailure (
1699- op, " outer: expected 1D tensors for outer op lowering" );
1700-
1701- Value lhsDim = getDimOp (rewriter, loc, lhs, 0 );
1702- Value rhsDim = getDimOp (rewriter, loc, rhs, 0 );
1703- Type elementType = lhsType.getElementType ();
1704- Type newResultType = getTypeConverter ()->convertType (op.getType ());
1705-
1706- // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1707- SmallVector<OpFoldResult> resultShape =
1708- getAsOpFoldResult (ValueRange{lhsDim, rhsDim});
1709- Value initTensor =
1710- rewriter.create <tensor::EmptyOp>(loc, resultShape, elementType);
1711-
1712- // Set up affine indexing maps:
1713- // We create a 2D loop iteration space. For the lhs, we use the first index
1714- // (i), for the rhs, the second index (j), and for the result, both (i, j).
1715- AffineMap mapLhs =
1716- AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (0 )},
1717- rewriter.getContext ());
1718- AffineMap mapRhs =
1719- AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (1 )},
1720- rewriter.getContext ());
1721- AffineMap mapOut =
1722- AffineMap::getMultiDimIdentityMap (2 , rewriter.getContext ());
1723-
1724- SmallVector<utils::IteratorType, 2 > iteratorTypes = {
1725- utils::IteratorType::parallel, utils::IteratorType::parallel};
1726-
1727- Value outerProd =
1728- rewriter
1729- .create <linalg::GenericOp>(
1730- loc, initTensor.getType (),
1731- /* inputs=*/ ValueRange{lhsDim, rhsDim},
1732- /* outputs=*/ initTensor,
1733- /* indexingMaps=*/
1734- SmallVector<AffineMap, 3 >{mapLhs, mapRhs, mapOut},
1735- /* iteratortType=*/ iteratorTypes,
1736- [&](OpBuilder &b, Location loc, ValueRange args) {
1737- Value lhsElem = args[0 ];
1738- Value rhsElem = args[1 ];
1739- Value mult = b.create <arith::MulFOp>(loc, lhsElem, rhsElem);
1740- b.create <linalg::YieldOp>(loc, mult);
1741- })
1742- .getResult (0 );
1743-
1744- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1745- return success ();
1746- }
1747- };
1748- } // namespace
1749-
17501676void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality (
17511677 TypeConverter &typeConverter, RewritePatternSet &patterns,
17521678 ConversionTarget &target) {
@@ -1763,6 +1689,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
17631689 patterns.add <ConvertAtenConvolutionOp>(typeConverter, context);
17641690 target.addIllegalOp <AtenFftRfftOp>();
17651691 patterns.add <ConvertAtenFftRfftOp>(typeConverter, context);
1766- target.addIllegalOp <AtenOuterOp>();
1767- patterns.add <ConvertAtenOuterOp>(typeConverter, context);
17681692}
0 commit comments