12 changes: 4 additions & 8 deletions lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -150,6 +150,10 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
         targetType = Torch::IntType::get(op->getContext());
         torchArg = typeConverter->materializeSourceConversion(
             rewriter, scfWhileOp.getLoc(), targetType, {to});
+      } else if (auto tty = dyn_cast<RankedTensorType>(targetType)) {
+        targetType = op.getIterArgsInit()[barg.index()].getType();
+        torchArg = typeConverter->materializeSourceConversion(
+            rewriter, scfWhileOp.getLoc(), targetType, {to});
       }
       if (!torchArg)
         return rewriter.notifyMatchFailure(op,
@@ -173,14 +177,6 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
                                            "unsupported type of the operand");
       loopConditionIterArgs.push_back(shouldContinue);
       for (auto torchArg : primLoopConditionOp.getIterArgs()) {
-        Type torchType = torchArg.getType();
-
-        // If the argument is a torch tensor, directly add it in the list of
-        // iter args.
-        if (isa<Torch::BaseTensorType>(torchType)) {
-          loopConditionIterArgs.push_back(torchArg);
-          continue;
-        }
         Value arg = typeConverter->materializeTargetConversion(
             rewriter, scfWhileOp->getLoc(),
             typeConverter->convertType(torchArg.getType()), {torchArg});
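The C++ change above extends `ConvertTorchPrimLoopWhileLikeOp` so that tensor-valued iteration arguments are materialized back to their original Torch tensor types when the loop regions are rewritten, which is what lets the new `while_loop` higher-order-op test lower through `scf.while`. For orientation, here is a minimal Python sketch of the carried-value semantics that `torch._higher_order_ops.while_loop` (and hence the lowered `scf.while`) implements; `while_loop_reference` is an illustrative name, not an API from torch or torch-mlir:

```python
def while_loop_reference(cond_fn, body_fn, carried):
    # Re-apply body_fn to the carried values while cond_fn holds on them.
    # Both functions see the full carried tuple; body_fn must return a
    # tuple of the same arity and types, mirroring the iter_args that the
    # scf.while lowering threads through the before/after regions.
    while cond_fn(*carried):
        carried = body_fn(*carried)
    return carried
```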
3 changes: 2 additions & 1 deletion projects/pt1/e2e_testing/main.py
@@ -40,6 +40,7 @@
 from .xfail_sets import (
     LINALG_XFAIL_SET,
     LINALG_CRASHING_SET,
+    TORCHSCRIPT_XFAIL_SET,
     STABLEHLO_PASS_SET,
     STABLEHLO_CRASHING_SET,
     TOSA_PASS_SET,
@@ -167,7 +168,7 @@ def main():
         crashing_set = set()
     elif args.config == "torchscript":
         config = TorchScriptTestConfig()
-        xfail_set = set()
+        xfail_set = TORCHSCRIPT_XFAIL_SET
         crashing_set = set()
     elif args.config == "lazy_tensor_core":
         config = LazyTensorCoreTestConfig()
17 changes: 17 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -89,6 +89,11 @@
     "AtenMmInt8Types_basic",
 }
 
+TORCHSCRIPT_XFAIL_SET = {
+    # Compilation Error: torch.jit.frontend.UnsupportedNodeError: import statements aren't supported:
+    "TorchPrimLoopWhileLikeHOPModule_basic",
+}
+
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
     # torch._dynamo.exc.Unsupported: Tensor.item
@@ -246,6 +251,8 @@
     "IsFloatingPointInt_False",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # torch._dynamo.exc.BackendCompilerFailed: Unsupported op: get_attr
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "ScalarConstantTupleModule_basic",
     # END tests failing due to: empty graph in dynamo
     # ERROR due to: backend never runs because of empty frame
@@ -481,6 +488,7 @@
     "TensorToBoolZeroRank_basic",
     "TensorToBool_basic",
     "ThresholdBackward2dMixedModule_basic",
+    "TorchPrimLoopWhileLikeHOPModule_basic",  # Compilation error: failed to legalize operation 'func.call'
     "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
     "UpSampleNearest2dDynamicFactor_basic",
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
@@ -993,6 +1001,8 @@
     "ElementwiseClampMinModule_bfloat16",
     "ElementwiseClampModule_bfloat16",
     "ElementwiseReluModule_bfloat16",
+    # Runtime error: failed to legalize operation 'torch.constant.int'
+    "TorchPrimLoopWhileLikeHOPModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -2575,6 +2585,7 @@
 
 LTC_XFAIL_SET = {
     "TorchPrimLoopForLikeTensorArgModule_basic" "CollapseAllDimensionsModule_basic",
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "CollapseRank1DynamicModule_basic",
     "CollapseStaticModule_basic",
     "CollapsePartialDynamicModule_basic",
@@ -3261,6 +3272,8 @@
     "ToCopyWithDTypeModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "TraceModule_basic",
     "TraceModule_empty",
     "TraceModule_nonsquare",
@@ -3957,6 +3970,8 @@
     "ThresholdBackward2dMixedModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # Runtime error: failed to legalize operation 'torch.aten.Bool.Tensor'
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "TraceModule_empty",
     "TraceUnsignedIntModule_empty",
     "TransposedConv1dNegativePadding_basic",
@@ -5036,6 +5051,8 @@
     "ToDtypeFloatFromIntModule_basic",
     "TorchPrimLoopForLikeModule_basic",
     "TorchPrimLoopWhileLikeModule_basic",
+    # RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
+    "TorchPrimLoopWhileLikeHOPModule_basic",
     "TraceModule_basic",
     "TraceModule_empty",
     "TraceModule_nonsquare",
34 changes: 34 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
@@ -10,6 +10,7 @@
 from torch_mlir_e2e_test.framework import TestUtils
 from torch_mlir_e2e_test.registry import register_test_case
 from torch_mlir_e2e_test.annotations import annotate_args, export
+from torch._higher_order_ops.while_loop import while_loop
 
 # ==============================================================================
 
@@ -78,3 +79,36 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
     x_test = torch.zeros([7, 9]).float()
 
     module.forward(x_test)
+
+
+# ==============================================================================
+
+
+class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def body_fn(self, i, x):
+        return i + 1, x + 1
+
+    def cond_fn(self, i, x):
+        return i < 3
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([7, 9], torch.float32, True),
+        ]
+    )
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        i0 = torch.tensor(0)
+        out_i, out_x = while_loop(self.cond_fn, self.body_fn, (i0, x))
+        return out_i, out_x
+
+
+@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule())
+def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils):
+    x_test = torch.zeros([7, 9]).float()
+
+    module.forward(x_test)
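For reference, a standalone eager-mode check of what this test's loop computes (illustrative only, not part of the suite, and subject to the installed torch version's eager support for `while_loop`): starting from `i = 0` with condition `i < 3`, the body runs three times, so `out_i == 3` and `out_x == x + 3`.

```python
import torch
from torch._higher_order_ops.while_loop import while_loop

x = torch.zeros(7, 9)
out_i, out_x = while_loop(
    lambda i, x: i < 3,           # predicate on the carried index tensor
    lambda i, x: (i + 1, x + 1),  # advance both carried values each trip
    (torch.tensor(0), x),
)
assert out_i.item() == 3
assert torch.equal(out_x, x + 3)
```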