Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yyp0 committed Sep 24, 2024
1 parent a07e372 commit 3fbbb22
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
1 change: 0 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7329,7 +7329,6 @@ class DecomposeAtenAdaptiveMaxPool1dOp
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);

int64_t outputSizeInt;
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,9 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)")
emit(
"aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
)
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
emit(
Expand Down
2 changes: 2 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,7 @@ def forward(self, x):
def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 10))


class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -1797,6 +1798,7 @@ def forward(self, x):
def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 512, 7))


# AdaptiveMaxPool2d


Expand Down

0 comments on commit 3fbbb22

Please sign in to comment.