@@ -1911,37 +1911,41 @@ class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
19111911 auto inputType = cast<BaseTensorType>(input.getType ());
19121912 auto vec2Type = cast<BaseTensorType>(vec2.getType ());
19131913
1914+ // Check if tensors not empty
1915+ if (!inputType.hasSizes () || !vec2Type.hasSizes ()) {
1916+ return rewriter.notifyMatchFailure (
1917+ op, " Inputs must be ranked tensors for aten.outer" );
1918+ }
1919+
19141920 // Check if both tensors are 1-dimensional
19151921 SmallVector<int64_t > inputShape (inputType.getSizes ());
19161922 SmallVector<int64_t > vec2Shape (vec2Type.getSizes ());
19171923
1918- if (inputShape.size () == 1 && vec2Shape.size () == 1 ) {
1924+ if (inputShape.size () != 1 || vec2Shape.size () != 1 ) {
1925+ return rewriter.notifyMatchFailure (
1926+ op, " Inputs must be 1-dimensional vectors for aten.outer" );
1927+ }
19191928
1920- Value one = rewriter.create <Torch::ConstantIntOp>(
1921- loc, rewriter.getI64IntegerAttr (1 )); // Dimension index
1922- SmallVector<int64_t , 2 > inputMatrixShape = {inputShape[0 ], 1 };
1923- Type inputMatrixType = inputType.getWithSizesAndDtype (
1924- inputMatrixShape, inputType.getOptionalDtype ());
1929+ Value one = rewriter.create <Torch::ConstantIntOp>(
1930+ loc, rewriter.getI64IntegerAttr (1 )); // Dimension index
1931+ SmallVector<int64_t , 2 > inputMatrixShape = {inputShape[0 ], 1 };
1932+ Type inputMatrixType = inputType.getWithSizesAndDtype (
1933+ inputMatrixShape, inputType.getOptionalDtype ());
19251934
1926- Value inputMatrix =
1927- rewriter.create <AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
1935+ Value inputMatrix =
1936+ rewriter.create <AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
19281937
1929- Value zero = rewriter.create <Torch::ConstantIntOp>(
1930- loc, rewriter.getI64IntegerAttr (0 ));
1931- SmallVector<int64_t , 2 > vec2MatrixShape = {1 , vec2Shape[0 ]};
1932- Type vec2MatrixType = vec2Type.getWithSizesAndDtype (
1933- vec2MatrixShape, vec2Type.getOptionalDtype ());
1934-
1935- Value vec2Matrix =
1936- rewriter.create <AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
1938+ Value zero = rewriter.create <Torch::ConstantIntOp>(
1939+ loc, rewriter.getI64IntegerAttr (0 ));
1940+ SmallVector<int64_t , 2 > vec2MatrixShape = {1 , vec2Shape[0 ]};
1941+ Type vec2MatrixType = vec2Type.getWithSizesAndDtype (
1942+ vec2MatrixShape, vec2Type.getOptionalDtype ());
19371943
1938- rewriter.replaceOpWithNewOp <AtenMatmulOp>(op, opType, inputMatrix,
1939- vec2Matrix);
1940- return success ();
1941- } else {
1942- return failure ();
1943- }
1944+ Value vec2Matrix =
1945+ rewriter.create <AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
19441946
1947+ rewriter.replaceOpWithNewOp <AtenMatmulOp>(op, opType, inputMatrix,
1948+ vec2Matrix);
19451949 return success ();
19461950 }
19471951};
0 commit comments