Commit e70be44

add forward+backward test

bosko-syrmia committed Sep 20, 2024
1 parent d0d70a4
Showing 3 changed files with 135 additions and 11 deletions.
94 changes: 90 additions & 4 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3563,7 +3563,7 @@ class DecomposeAtenRreluWithNoiseOp
    Value noise = op.getNoise();
    Value lower = op.getLower();
    Value upper = op.getUpper();
-   auto resType = cast<BaseTensorType>(self.getType());
+   auto resType = cast<BaseTensorType>(op.getType());
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }
@@ -3609,13 +3609,12 @@ class DecomposeAtenRreluWithNoiseOp
                                                      rewriter.getI1Type());
      Value oneTensor =
          createRank0Tensor(rewriter, loc, resType, constantOneFloat);
-     Value not_positive = rewriter.create<AtenLeScalarOp>(
+     Value not_positive = rewriter.create<AtenLtScalarOp>(
          loc, boolResType, self, constantZeroFloat);
      noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
-                                              noise, oneTensor);
+                                              alpha, oneTensor);
    } else {
      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
-     noise = alpha;
    }

    Value negativeOutput =
@@ -3628,6 +3627,93 @@ class DecomposeAtenRreluWithNoiseOp
};
} // namespace

// namespace {
// class DecomposeAtenRreluWithNoiseOp
//     : public OpRewritePattern<AtenRreluWithNoiseOp> {
// public:
//   using OpRewritePattern::OpRewritePattern;
//   LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
//                                 PatternRewriter &rewriter) const override {
//     Location loc = op.getLoc();
//     Value self = op.getSelf();
//     Value noise = op.getNoise();
//     Value lower = op.getLower();
//     Value upper = op.getUpper();
//     auto resType = cast<BaseTensorType>(op.getType());
//     if (!resType.hasDtype()) {
//       return rewriter.notifyMatchFailure(op, "result should have dtype");
//     }

//     bool training;
//     if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
//       return rewriter.notifyMatchFailure(op, "training should be a constant");
//     }

//     Value constantZeroFloat =
//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
//     Value constantOneFloat =
//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
//     Value constantTwoFloat =
//         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));

//     // Value alpha;
//     // if (training) {
//     //   Value none = rewriter.create<ConstantNoneOp>(loc);
//     //   Value emptyTensor = rewriter.create<AtenFullLikeOp>(
//     //       loc, resType, self, constantZeroFloat, /*dtype=*/none,
//     //       /*layout=*/none,
//     //       /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
//     //   alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
//     //                                          /*from=*/lower, /*to=*/upper,
//     //                                          /*generator=*/none);
//     // } else {
//     //   Value half = rewriter.create<AtenAddOp>(
//     //       loc, constantTwoFloat.getType(), lower, upper);
//     //   alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(),
//     //                                      half, constantTwoFloat);
//     // }

//     Value zeroTensor =
//         createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
//     Value positiveOutput =
//         rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);

//     Value scaledSelf;
//     if (training) {
//       scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, noise);
//       auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
//                                                       rewriter.getI1Type());
//       Value oneTensor =
//           createRank0Tensor(rewriter, loc, resType, constantOneFloat);
//       Value not_positive = rewriter.create<AtenLeScalarOp>(
//           loc, boolResType, self, constantZeroFloat);
//       noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
//                                                noise, oneTensor);
//     } else {
//       Value half = rewriter.create<AtenAddOp>(
//           loc, constantTwoFloat.getType(), lower, upper);
//       Value alpha = rewriter.create<AtenDivOp>(
//           loc, constantTwoFloat.getType(), half, constantTwoFloat);
//       scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
//     }

//     Value negativeOutput =
//         rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
//     Value rreluOutput = rewriter.create<AtenAddTensorOp>(
//         loc, resType, positiveOutput, negativeOutput, constantOneFloat);
//     rewriter.replaceOp(op, rreluOutput);
//     return success();
//   }
// };
// } // namespace
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
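For reference, the training-mode decomposition above corresponds to the following semantics. This is a minimal Python sketch of the assumed behavior, not code from the commit; the helper name is hypothetical, and `noise` is treated as an output that records the applied slopes:

import torch

def rrelu_with_noise_reference(x, noise, lower, upper, training):
    zero = torch.zeros_like(x)
    if training:
        # Per-element slope sampled uniformly from [lower, upper], as the
        # AtenUniformOp in the pattern does.
        alpha = torch.empty_like(x).uniform_(lower, upper)
        # max(0, x) + min(0, alpha * x): identity for x >= 0, scaled otherwise.
        out = torch.maximum(zero, x) + torch.minimum(zero, x * alpha)
        # The AtenLtScalarOp/AtenWhereSelfOp pair computes the applied slope:
        # alpha where x < 0, and 1 elsewhere.
        noise.copy_(torch.where(x < 0, alpha, torch.ones_like(x)))
        return out
    # Eval mode uses the fixed mean slope instead of sampling.
    alpha = (lower + upper) / 2
    return torch.maximum(zero, x) + torch.minimum(zero, x * alpha)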
38 changes: 38 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py
@@ -448,3 +448,41 @@ def forward(self, grad, input, noise):
@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule())
def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5))


class RreluWithNoiseForwardBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, grad, input, noise):
        torch.ops.aten.rrelu_with_noise(
            input, noise, lower=0.4, upper=0.6, training=True
        )
        res = torch.ops.aten.rrelu_with_noise_backward(
            grad,
            input,
            noise,
            lower=0.4,
            upper=0.6,
            training=True,
            self_is_result=False,
        )
        return torch.mean(res), torch.std(res)


@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule())
def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(256, 244),
        tu.rand(256, 244, low=-1.0, high=1.0),
        tu.rand(256, 244, low=0.4, high=0.6),
    )
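The forward call in this test discards its result on purpose: with training=True, rrelu_with_noise is assumed to record the sampled negative slopes in `noise` in place, and the backward op then consumes them. A rough sketch of the semantics being checked, under that assumption (the helper name is hypothetical, not part of the commit):

import torch

def rrelu_with_noise_backward_reference(
    grad, x, noise, lower, upper, training, self_is_result
):
    # self_is_result only changes which tensor `x` refers to; it is
    # ignored in this sketch.
    if training and upper > lower:
        # `noise` already holds the per-element slope applied in the forward
        # pass (alpha where x < 0, 1 elsewhere), so the gradient is an
        # elementwise product.
        return grad * noise
    # Otherwise the op degenerates to leaky_relu backward with the mean slope.
    negative_slope = (lower + upper) / 2
    return torch.where(x > 0, grad, grad * negative_slope)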
14 changes: 7 additions & 7 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1188,13 +1188,13 @@ def __init__(self):
        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
    )
    def forward(self, x, noise):
-       res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
+       res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True)
        return torch.mean(res), torch.std(res)


@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule())
def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils):
-   module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536)))
+   module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))


# ==============================================================================
@@ -1206,16 +1206,16 @@ def __init__(self):

    @export
    @annotate_args(
-       [None, ([1024, 1536], torch.float32, True), ([1024, 1536], torch.float32, True)]
+       [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)]
    )
    def forward(self, x, noise):
-       res = torch.ops.aten.rrelu_with_noise(x, noise, 0.1, 0.9, True)
+       res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True)
        return torch.mean(res), torch.std(res)


@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule())
def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils):
-   module.forward(tu.rand(1024, 1536), torch.zeros((1024, 1536)))
+   module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128))


# ==============================================================================
@@ -1236,7 +1236,7 @@ def forward(self, x, noise):

@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule())
def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils):
-   module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3)))
+   module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))


# ==============================================================================
@@ -1255,7 +1255,7 @@ def forward(self, x, noise):

@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule())
def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils):
-   module.forward(tu.rand(5, 3, low=-1, high=1), torch.zeros((5, 3)))
+   module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3))


# ==============================================================================
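The Eval test updates above exercise the inference path, where the decomposition collapses to a fixed-slope LeakyReLU. A one-line sketch of that assumed behavior (helper name hypothetical):

import torch.nn.functional as F

def rrelu_eval_reference(x, lower, upper):
    # Inference mode: no sampling, just the mean of the slope range.
    return F.leaky_relu(x, negative_slope=(lower + upper) / 2)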
