Commit
adjust test
bosko-syrmia committed Sep 27, 2024
1 parent e70be44 commit 210e5e8
Showing 2 changed files with 5 additions and 95 deletions.
87 changes: 0 additions & 87 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3627,93 +3627,6 @@ class DecomposeAtenRreluWithNoiseOp
};
} // namespace

// namespace {
// class DecomposeAtenRreluWithNoiseOp
//     : public OpRewritePattern<AtenRreluWithNoiseOp> {
// public:
//   using OpRewritePattern::OpRewritePattern;
//   LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
//                                 PatternRewriter &rewriter) const override {
//     Location loc = op.getLoc();
//     Value self = op.getSelf();
//     Value noise = op.getNoise();
//     Value lower = op.getLower();
//     Value upper = op.getUpper();
//     auto resType = cast<BaseTensorType>(op.getType());
//     if (!resType.hasDtype()) {
//       return rewriter.notifyMatchFailure(op, "result should have dtype");
//     }
//
//     bool training;
//     if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
//       return rewriter.notifyMatchFailure(op, "training should be a constant");
//     }
//
//     Value constantZeroFloat =
//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
//     Value constantOneFloat =
//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
//     Value constantTwoFloat =
//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
//
//     // Value alpha;
//     // if (training) {
//     //   Value none = rewriter.create<ConstantNoneOp>(loc);
//     //   Value emptyTensor = rewriter.create<AtenFullLikeOp>(
//     //       loc, resType, self, constantZeroFloat, /*dtype=*/none,
//     //       /*layout=*/none, /*device=*/none, /*pin_memory=*/none,
//     //       /*memory_format=*/none);
//     //   alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
//     //                                          /*from=*/lower, /*to=*/upper,
//     //                                          /*generator=*/none);
//     // } else {
//     //   Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
//     //                                           lower, upper);
//     //   alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(),
//     //                                      half, constantTwoFloat);
//     // }
//
//     Value zeroTensor =
//         createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
//     Value positiveOutput =
//         rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
//
//     Value scaledSelf;
//     if (training) {
//       scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, noise);
//       auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
//                                                       rewriter.getI1Type());
//       Value oneTensor =
//           createRank0Tensor(rewriter, loc, resType, constantOneFloat);
//       Value not_positive = rewriter.create<AtenLeScalarOp>(
//           loc, boolResType, self, constantZeroFloat);
//       noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
//                                                noise, oneTensor);
//     } else {
//       Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
//                                               lower, upper);
//       Value alpha = rewriter.create<AtenDivOp>(
//           loc, constantTwoFloat.getType(), half, constantTwoFloat);
//       scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
//     }
//
//     Value negativeOutput =
//         rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
//     Value rreluOutput = rewriter.create<AtenAddTensorOp>(
//         loc, resType, positiveOutput, negativeOutput, constantOneFloat);
//     rewriter.replaceOp(op, rreluOutput);
//     return success();
//   }
// };
// } // namespace
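For reference, the decomposition this deleted commented-out pattern implemented corresponds to the following eager-mode sketch. It is illustrative only: the helper name is an assumption, not code from this patch, and in training mode it assumes `noise` was already populated with sampled slopes by a prior `rrelu_with_noise` call.

import torch

def rrelu_with_noise_reference(x, noise, lower, upper, training):
    # rrelu(x) = max(0, x) + min(0, scaled_x), mirroring the deleted pattern.
    zero = torch.zeros_like(x)
    positive = torch.maximum(zero, x)
    if training:
        # Negative inputs are scaled by the per-element slopes held in
        # `noise` (expected to be 1 wherever x > 0).
        scaled = x * noise
    else:
        # In eval mode the random slope collapses to the mean of the range.
        scaled = x * ((lower + upper) / 2)
    return positive + torch.minimum(zero, scaled)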

// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
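The CELU comment above states the formula the next (truncated) pattern decomposes; as a hedged eager-mode illustration (the helper name and default alpha are assumptions, not code from this repository):

import torch

def celu_reference(x, alpha=1.0):
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    zero = torch.zeros_like(x)
    return torch.maximum(zero, x) + torch.minimum(
        zero, alpha * torch.expm1(x / alpha))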
13 changes: 5 additions & 8 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py
@@ -464,9 +464,6 @@ def __init__(self):
]
)
def forward(self, grad, input, noise):
-        torch.ops.aten.rrelu_with_noise(
-            input, noise, lower=0.4, upper=0.6, training=True
-        )
res = torch.ops.aten.rrelu_with_noise_backward(
grad,
input,
@@ -481,8 +478,8 @@

@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
-    module.forward(
-        tu.rand(256, 244),
-        tu.rand(256, 244, low=-1.0, high=1.0),
-        tu.rand(256, 244, low=0.4, high=0.6),
-    )
+    grad = tu.rand(256, 244)
+    input = tu.rand(256, 244, low=-1.0, high=1.0)
+    noise = tu.rand(256, 244)
+    torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True)
+    module.forward(grad, input, noise)
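The adjusted test populates `noise` eagerly: when `training=True`, `torch.ops.aten.rrelu_with_noise` writes the sampled negative-slope values into `noise` in place, so calling it before `module.forward` leaves only the backward op inside the compiled graph. A rough reference for the backward semantics being exercised (a sketch based on ATen's documented behavior; the helper name is an assumption, and the degenerate-range and `self_is_result` cases are ignored):

import torch

def rrelu_with_noise_backward_reference(grad, x, noise, lower, upper, training):
    if training:
        # The sampled slopes live in `noise` (1 wherever x > 0), so the
        # gradient is an elementwise rescale by `noise`.
        return grad * noise
    # Eval mode degenerates to leaky_relu with the mean slope.
    alpha = (lower + upper) / 2
    return torch.where(x > 0, grad, grad * alpha)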
