torch.cond operator not supported on a simple example (error: failed to legalize operation 'torch.operator' that was explicitly marked illegal) #937

Open
JibAxelera opened this issue Feb 14, 2025 · 0 comments

Problem:

I am trying to compile a dummy model whose computation graph depends on its input (using torch.cond), but iree-compile fails with the error "error: failed to legalize operation 'torch.operator' that was explicitly marked illegal".

Steps to reproduce:

Run the following code:

```python
import torch
import torch.nn as nn
import copy

class CondNetwork(nn.Module):

    def __init__(self):
        super().__init__()

        self.confidence_threshold = 2

    # Both branches take no operands and return a tensor with the same
    # shape and dtype, as torch.cond requires.
    def true_fn(self):
        return torch.rand(10, dtype=torch.float32)

    def false_fn(self):
        return torch.rand(10, dtype=torch.float32)

    def forward(self, x):

        with torch.no_grad():

            # Data-dependent predicate: a 0-d boolean tensor.
            condition = x.sum() > self.confidence_threshold

            return torch.cond(condition, self.true_fn, self.false_fn)

def model_export(model, device):

    model.eval()

    cond_model = copy.deepcopy(model).to(device)

    x = torch.randn(1, 3, 32, 32).to(device)

    # Export with the dynamo-based exporter; torch.cond ends up as an onnx.If node.
    with torch.no_grad():
        cond_model.eval()
        torch.onnx.export(cond_model, x, './conditional_model.onnx', verbose=True, dynamo=True, report=True)

###-- Main
def main():

    model = CondNetwork()

    model.cuda()

    model_export(model, device="cuda")

if __name__ == '__main__':
    main()
```

After running this code, run:

```shell
iree-import-onnx conditional_model.onnx -o conditional_model.mlir
```

Then run:

```shell
iree-compile conditional_model.mlir -o conditional_model.vmfb
```

which results in the following error:

```
conditional_model.mlir:12:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
      %5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32>
           ^
conditional_model.mlir:12:12: note: see current operation: %16 = "torch.operator"() <{name = "onnx.true_graph_0"}> : () -> !torch.vtensor<[10],f32>
```
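To re-run the failure quickly after changing the model, the import and compile steps can be chained from Python. This is a minimal sketch, assuming `iree-import-onnx` and `iree-compile` are on PATH; it reuses the exact commands and file names from the steps above, and `STEPS` and `run_steps` are hypothetical helper names, not part of any IREE API.

```python
import shlex
import subprocess

# Assumption: both IREE command-line tools are installed and on PATH.
# These are the same commands as in the manual reproduction steps.
STEPS = [
    ["iree-import-onnx", "conditional_model.onnx", "-o", "conditional_model.mlir"],
    ["iree-compile", "conditional_model.mlir", "-o", "conditional_model.vmfb"],
]


def run_steps(steps, dry_run=False):
    """Run each command in order, stopping at the first failure.

    Returns the shell-quoted command strings that were executed
    (or, with dry_run=True, that would have been executed).
    """
    executed = []
    for argv in steps:
        executed.append(shlex.join(argv))
        if dry_run:
            continue
        # check=True raises CalledProcessError on a non-zero exit status;
        # stderr is not captured, so the compiler diagnostics print directly.
        subprocess.run(argv, check=True)
    return executed
```

Calling `run_steps(STEPS)` after the export should stop at the `iree-compile` step with a `CalledProcessError`, with the legalization error printed above it; `run_steps(STEPS, dry_run=True)` only lists the commands.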

Is there something wrong with my implementation, or is this operation simply not supported?

Additional information:

Package versions:
- torch: 2.6.0
- iree-turbine: 3.2.0

Associated IR (file conditional_model.mlir):

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.opset_versions = {pkg.onnxscript.torch_lib.common = 1 : si64, pkg.torch.__subgraph__ = 1 : si64}, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.6.0+cu124"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.ReduceSum"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.noop_with_empty_axes = 0 : si64} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[],f32>
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<si64>} : () -> !torch.vtensor<[],si64>
    %2 = torch.operator "onnx.Cast"(%1) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],f32>
    %3 = torch.operator "onnx.Greater"(%0, %2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[],i1>
    %4 = torch.operator "onnx.If"(%3) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[10],f32> {
      %5 = torch.operator "onnx.false_graph_0"() : () -> !torch.vtensor<[10],f32>
      torch.operator_terminator %5 : !torch.vtensor<[10],f32>
    }, {
      %5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32>
      torch.operator_terminator %5 : !torch.vtensor<[10],f32>
    }
    return %4 : !torch.vtensor<[10],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _: "0x080000000200000000000000"
    }
  }
#-}
```
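In the IR above, each branch of the `onnx.If` op is a single call to a name (`onnx.false_graph_0`, `onnx.true_graph_0`) that is not an ONNX operator and has no definition anywhere in the module; the second one is exactly the op `iree-compile` reports as illegal. These names presumably refer to the exporter's branch subgraphs (note the `pkg.torch.__subgraph__` entry in `opset_versions`), but that is an assumption. The sketch below scans an imported `.mlir` file for such calls; the `<name>_graph_<N>` pattern is a hypothetical heuristic based only on the names seen here, and `find_subgraph_calls` is an illustrative helper.

```python
import re

# Matches the quoted operator name in lines such as:
#   %5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32>
# It does not match "torch.operator_terminator", which has no quoted name.
OPERATOR_NAME = re.compile(r'torch\.operator "([^"]+)"')

# Hypothetical heuristic: onnx.<prefix>_graph_<N> names are calls to the
# torch.cond branch subgraphs rather than real ONNX operators.
SUBGRAPH_CALL = re.compile(r"onnx\.\w+_graph_\d+")


def find_subgraph_calls(mlir_text):
    """Return operator names that look like subgraph calls, in order of appearance."""
    names = OPERATOR_NAME.findall(mlir_text)
    return [name for name in names if SUBGRAPH_CALL.fullmatch(name)]
```

Running `find_subgraph_calls(open("conditional_model.mlir").read())` on the IR above should return `["onnx.false_graph_0", "onnx.true_graph_0"]`, while regular operators such as `onnx.ReduceSum` or `onnx.If` are ignored.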