Skip to content

Commit

Permalink
Do not try to legalize transposed convolution (#2721)
Browse files Browse the repository at this point in the history
Currently transposed convolution is not handled correctly by
`TorchToTosa`. This PR allows transposed convolutions to pass through
the conversion so that they can be handled by other conversion passes
later in a pipeline.

An example input which produces a compilation error is:

```
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
  %true = torch.constant.bool true
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
  %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
  return %output : !torch.vtensor<[1,64,2,200],f32>
}
```

This MLIR produces an error about a cast operation with a size mismatch
when passed through `torch-to-tosa`:

```
 error: 'tensor.cast' op operand type 'tensor<1x64x1x50xf32>' and result type 'tensor<1x64x2x200xf32>' are cast incompatible
```

---------

Co-authored-by: Srinath Avadhanula <[email protected]>
  • Loading branch information
srinathava and Srinath Avadhanula authored Jan 22, 2024
1 parent b9806cf commit 73b3060
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
8 changes: 8 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,14 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: non-constant value for transposed not supported");
if (transposed)
return rewriter.notifyMatchFailure(
op, "Unimplemented: transposed convolution not supported");

auto input = adaptor.getInput();
auto weight = adaptor.getWeight();

Expand Down
18 changes: 18 additions & 0 deletions test/Conversion/TorchToTosa/conv2d_transpose.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics

// The following test ensures that a tranposed convolution op is not
// lowered in the torch-to-tosa conversion pass.

func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
%true = torch.constant.bool true
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32>
%bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
return %output : !torch.vtensor<[1,64,2,200],f32>
}

0 comments on commit 73b3060

Please sign in to comment.