From 61132380fee28c5e0e467739c55a54e7c6e85dcc Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Tue, 23 Apr 2024 16:12:43 +0000
Subject: [PATCH 1/8] [stablehlo] Support aten.any and aten.all lowering

---
 lib/Conversion/TorchToStablehlo/Reduction.cpp | 158 ++++++++++++++++++
 1 file changed, 158 insertions(+)

diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index c525c8b40de5..af3d2450a579 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -89,6 +89,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
+  if (isa<AtenAllOp>(op)) {
+    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
+  }
+
+  if (isa<AtenAnyOp>(op)) {
+    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
+  }
+
   op->emitError("unimplemented lowering in "
                 "createInitialValueForReduceOp");
   return nullptr;
@@ -448,6 +460,150 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite(
 }
 } // namespace
 
+// AtenAllOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
+    AtenAllOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenAllOp to StableHLO");
+  }
+  auto outTy = getTypeConverter()
+                   ->convertType(op.getType())
+                   .template dyn_cast<RankedTensorType>();
+
+  if (inputElemTy != outTy.getElementType()) {
+    // Use output bool type as computation type.
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = input.getType().dyn_cast<RankedTensorType>();
+    inputElemTy = inputTy.getElementType();
+  }
+
+  SmallVector<int64_t> dims;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    dims.push_back(i);
+  }
+
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return failure();
+  llvm::sort(dims.begin(), dims.end());
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
+      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value allResult = rewriter.create<stablehlo::AndOp>(
+        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()),
+      stablehloReduceOp.getResults());
+  return success();
+}
+} // namespace
+
+// AtenAnyOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
+    AtenAnyOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+
+  // Currently, (u)int8 dtype is not supported
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenAnyOp to StableHLO");
+  }
+  auto outTy = getTypeConverter()
+                   ->convertType(op.getType())
+                   .template dyn_cast<RankedTensorType>();
+
+  if (inputElemTy != outTy.getElementType()) {
+    // Use output bool type as computation type.
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = input.getType().dyn_cast<RankedTensorType>();
+    inputElemTy = inputTy.getElementType();
+  }
+
+  SmallVector<int64_t> dims;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    dims.push_back(i);
+  }
+
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return failure();
+  llvm::sort(dims.begin(), dims.end());
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
+      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value anyResult = rewriter.create<stablehlo::OrOp>(
+        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()),
+      stablehloReduceOp.getResults());
+  return success();
+}
+} // namespace
+
 // AtenMaxOp
 namespace {
 template <>
@@ -957,6 +1113,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);

From 6ee34f658a44dda0dab8d9fbbe618377c5536761 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Tue, 23 Apr 2024 17:02:48 +0000
Subject: [PATCH 2/8] pass linalg and add test

---
 lib/Conversion/TorchToLinalg/Reduction.cpp |  24 +++-
 projects/pt1/e2e_testing/xfail_sets.py     |   6 +
 .../test_suite/reduction.py                | 114 ++++++++++++++++++
 3 files changed, 138 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index bd8b1fc6bfb1..6d58b0824b0a 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
       isa<AtenLinalgVectorNormOp, AtenFrobeniusNormDimOp, AtenNormScalarOp>(op))
     return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
 
-  if (isa<AtenAllDimOp>(op)) {
+  if (isa<AtenAllOp, AtenAllDimOp>(op)) {
     return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
   }
 
+  if (isa<AtenAnyOp>(op)) {
+    return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
+  }
+
   op->emitError("unimplemented lowering in createInitElementForReduceOp");
   return nullptr;
 }
@@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
     auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
     auto pow = b.create<math::PowFOp>(loc, abs, ord);
     return b.create<arith::AddFOp>(loc, pow, result);
-  } else if (isa<AtenAllDimOp>(op)) {
+  } else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
+    Value elem = payloadArgs[0];
+    Value result = payloadArgs[1];
+    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
+    return b.create<arith::AndIOp>(loc, self, result);
+  } else if (isa<AtenAnyOp>(op)) {
     Value elem = payloadArgs[0];
     Value result = payloadArgs[1];
     Value self = convertScalarToDtype(b, loc, elem, resultElementType);
-    return b.create<arith::AndIOp>(loc, self, result);
+    return b.create<arith::OrIOp>(loc, self, result);
   }
   op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
   return nullptr;
 }
@@ -510,12 +519,13 @@ class ConvertReductionOp : public ConversionPattern {
                          ConversionPatternRewriter &rewriter) const {
     auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
 
-    if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenNormScalarOp>(op)) {
+    if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenAllOp, AtenAnyOp,
+            AtenNormScalarOp>(op)) {
       opInfo.tensorOperand = operands[0];
       auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
 
-      // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the
-      // dimensions of the input tensor.
+      // `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each
+      // reduce along all the dimensions of the input tensor.
       for (int64_t i = 0; i < inputType.getRank(); i++)
         opInfo.dimSet.insert(i);
 
@@ -714,6 +724,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   target.addIllegalOp<AtenMinDimOp>();
   patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
   target.addIllegalOp<AtenSumOp>();
+  target.addIllegalOp<AtenAllOp>();
+  target.addIllegalOp<AtenAnyOp>();
   target.addIllegalOp<AtenSumDimIntListOp>();
   target.addIllegalOp<AtenProdDimIntOp>();
   target.addIllegalOp<AtenMaxOp>();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 80ab03566bf0..5a860c83e55e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1223,6 +1223,12 @@
     "RandIntLowModule_basic",
     "RandIntModule_basic",
     "RandIntPinMemoryModule_basic",
+    "ReduceAllFloat_basic",
+    "ReduceAllInt_basic",
+    "ReduceAllBool_basic",
+    "ReduceAnyFloat_basic",
+    "ReduceAnyInt_basic",
+    "ReduceAnyBool_basic",
     "ReduceAmaxMultiDim_basic",
     "ReduceAmaxOutOfOrderDim_basic",
     "ReduceAmaxSingleDim_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index b6178221eb48..1f4fa877dbdb 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -68,6 +68,120 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ReduceAllFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllFloatModule())
+def ReduceAllFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ReduceAllIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllIntModule())
+def ReduceAllIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
+
+# ==============================================================================
+
+class ReduceAllBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllBoolModule())
+def ReduceAllBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, high=2).to(torch.bool))
+
+# ==============================================================================
+
+class ReduceAnyFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.any(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyFloatModule())
+def ReduceAnyFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ReduceAnyIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.any(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyIntModule())
+def ReduceAnyIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
+
+# ==============================================================================
+
+class ReduceAnyBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.any(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyBoolModule())
+def ReduceAnyBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, high=2).to(torch.bool))
+
+# ==============================================================================
+
 class ReduceSumDimIntListFloatModule(torch.nn.Module):
     def __init__(self):

From ed97da01ab7e04c5ae2f9e99edbc361d6774b568 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Tue, 23 Apr 2024 17:14:17 +0000
Subject: [PATCH 3/8] tosa pass

---
 projects/pt1/e2e_testing/xfail_sets.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 5a860c83e55e..12df434495ac 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1801,6 +1801,8 @@
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "ReduceAllBoolModule_basic",
+    "ReduceAnyBoolModule_basic",
    "ReduceSumDimIntListFloatModule_basic",
    "ReduceSumDimIntListIntModule_basic",
    "ReduceSumDimIntListKeepDimFloatModule_basic",

From e9e8b8573e9ecfebaae60b964ef960eba1df82e3 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Tue, 23 Apr 2024 17:43:53 +0000
Subject: [PATCH 4/8] modify test name

---
 projects/pt1/e2e_testing/xfail_sets.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 12df434495ac..ae69bcd7e3cf 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1223,12 +1223,12 @@
     "RandIntLowModule_basic",
     "RandIntModule_basic",
     "RandIntPinMemoryModule_basic",
-    "ReduceAllFloat_basic",
-    "ReduceAllInt_basic",
-    "ReduceAllBool_basic",
-    "ReduceAnyFloat_basic",
-    "ReduceAnyInt_basic",
-    "ReduceAnyBool_basic",
+    "ReduceAllFloatModule_basic",
+    "ReduceAllIntModule_basic",
+    "ReduceAllBoolModule_basic",
+    "ReduceAnyFloatModule_basic",
+    "ReduceAnyIntModule_basic",
+    "ReduceAnyBoolModule_basic",
     "ReduceAmaxMultiDim_basic",
     "ReduceAmaxOutOfOrderDim_basic",
     "ReduceAmaxSingleDim_basic",

From 4e69a71459e2450d321bd0e7ef78eddd3f9671c9 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Wed, 24 Apr 2024 03:03:08 +0000
Subject: [PATCH 5/8] fail onnx

---
 projects/pt1/e2e_testing/xfail_sets.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ae69bcd7e3cf..d918af6e4e14 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2696,6 +2696,7 @@
     "MaskedFillTensorFloatValueModule_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
+    "ReduceAnyFloatModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
 }

From 36735800f7311254b53670c46cb367683d49c0e7 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Wed, 24 Apr 2024 06:16:23 +0000
Subject: [PATCH 6/8] fix conflict

---
 .../test_suite/reduction.py | 112 ++++++++++--------
 1 file changed, 62 insertions(+), 50 deletions(-)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index dd07777e9068..1fac4d986399 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -68,11 +68,7 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-<<<<<<< HEAD
-class ReduceAllFloatModule(torch.nn.Module):
-=======
 class ReduceProdFloatModule(torch.nn.Module):
->>>>>>> upstream/main
     def __init__(self):
         super().__init__()
 
@@ -82,49 +78,22 @@ def __init__(self):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
-<<<<<<< HEAD
-        return torch.ops.aten.all(a)
-
-
-@register_test_case(module_factory=lambda: ReduceAllFloatModule())
-def ReduceAllFloatModule_basic(module, tu: TestUtils):
-=======
         return torch.prod(a)
 
 
 @register_test_case(module_factory=lambda: ReduceProdFloatModule())
 def ReduceProdFloatModule_basic(module, tu: TestUtils):
->>>>>>> upstream/main
     module.forward(tu.rand(3, 4, 5))
 
 # ==============================================================================
 
-<<<<<<< HEAD
-class ReduceAllIntModule(torch.nn.Module):
-=======
 class ReduceProdDtypeFloatModule(torch.nn.Module):
->>>>>>> upstream/main
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
-<<<<<<< HEAD
-        ([-1, -1, -1], torch.int32, True),
-    ])
-    def forward(self, a):
-        return torch.ops.aten.all(a)
-
-
-@register_test_case(module_factory=lambda: ReduceAllIntModule())
-def ReduceAllIntModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
-
-# ==============================================================================
-
-class ReduceAllBoolModule(torch.nn.Module):
-=======
         ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, a):
@@ -137,26 +106,80 @@ def ReduceProdDtypeFloatModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
 class ReduceProdElementTypeBoolModule(torch.nn.Module):
->>>>>>> upstream/main
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
-<<<<<<< HEAD
+        ([-1, -1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.prod(a)
+
+
+@register_test_case(module_factory=lambda: ReduceProdElementTypeBoolModule())
+def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
+
+# ==============================================================================
+class ReduceAllFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllFloatModule())
+def ReduceAllFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ReduceAllIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllIntModule())
+def ReduceAllIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
+
+# ==============================================================================
+
+class ReduceAllBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
         ([-1, -1], torch.bool, True),
     ])
     def forward(self, a):
         return torch.ops.aten.all(a)
-    
+
 
 @register_test_case(module_factory=lambda: ReduceAllBoolModule())
 def ReduceAllBoolModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, high=2).to(torch.bool))
 
 # ==============================================================================
-    
+
 
 class ReduceAnyFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.any(a)
 
 
 @register_test_case(module_factory=lambda: ReduceAnyFloatModule())
 def ReduceAnyFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
 # ==============================================================================
-    
+
 
 class ReduceAnyIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
         ([-1, -1, -1], torch.int32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.any(a)
-    
+
 
 @register_test_case(module_factory=lambda: ReduceAnyIntModule())
 def ReduceAnyIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
 
 # ==============================================================================
-    
+
 
 class ReduceAnyBoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
         ([-1, -1], torch.bool, True),
     ])
     def forward(self, a):
         return torch.ops.aten.any(a)
-    
+
 
 @register_test_case(module_factory=lambda: ReduceAnyBoolModule())
 def ReduceAnyBoolModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, high=2).to(torch.bool))
-=======
-        ([-1, -1, -1], torch.bool, True),
-    ])
-    def forward(self, a):
-        return torch.prod(a)
-
-
-@register_test_case(module_factory=lambda: ReduceProdElementTypeBoolModule())
-def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):
-    module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
->>>>>>> upstream/main
 
 # ==============================================================================
 

From c21261c2c4bac0b458c61036795f978375a5a226 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Wed, 24 Apr 2024 06:18:22 +0000
Subject: [PATCH 7/8] fix

---
 projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 1fac4d986399..076dd4e458a4 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -123,6 +123,7 @@ def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
 
 # ==============================================================================
+
 class ReduceAllFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

From 445985393495510128695bc435e48927a2066009 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Wed, 24 Apr 2024 06:21:08 +0000
Subject: [PATCH 8/8] lint

---
 lib/Conversion/TorchToLinalg/Reduction.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index 56ea622be00c..a5238c9b1211 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -520,13 +520,12 @@ class ConvertReductionOp : public ConversionPattern {
     auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
 
     if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp, AtenAllOp, AtenAnyOp,
-            AtenNormScalarOp>(
-            op)) {
+            AtenNormScalarOp>(op)) {
       opInfo.tensorOperand = operands[0];
       auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
 
-      // `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each
-      // reduce along all the dimensions of the input tensor.
+      // `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
+      // `AtenMinOp` each reduce along all the dimensions of the input tensor.
       for (int64_t i = 0; i < inputType.getRank(); i++)
         opInfo.dimSet.insert(i);
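
Editor's note: the snippet below is an illustrative sketch, not part of the patch series, showing in plain PyTorch the reduction semantics the new lowerings encode. aten.all folds every element with logical AND starting from the identity true (the APInt(1, 1) / getBoolAttr(true) init values above), while aten.any folds with logical OR starting from the identity false (APInt(1, 0) / getBoolAttr(false)). The e2e tests added in patch 2 exercise exactly these ops on float, int32, and bool inputs.

import torch

x = torch.tensor([[True, False], [True, True]])

# aten.all: AND-reduction over all dimensions; its identity element is True.
assert not torch.ops.aten.all(x).item()

# aten.any: OR-reduction over all dimensions; its identity element is False.
assert torch.ops.aten.any(x).item()

# Non-bool inputs reduce on element truthiness, as in the Int/Float test modules.
y = torch.randint(0, 2, (3, 4, 5), dtype=torch.int32)
print(torch.ops.aten.all(y), torch.ops.aten.any(y))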