Skip to content

Commit

Permalink
Linalg lowering for aten.conv2d(bias=True)
Browse files Browse the repository at this point in the history
Previously aten.conv2d was only lowered if there was no bias.
Here lowering is extended to support bias.
  • Loading branch information
ljfitz authored and silvasean committed Dec 8, 2021
1 parent c598e01 commit 2414bdb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 11 deletions.
22 changes: 22 additions & 0 deletions e2e_testing/torchscript/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
module.forward(t)


class Conv2dBiasNoPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(2, 10, 3, bias=True)
self.train(False)

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dBiasNoPaddingModule())
def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)


class Conv2dWithPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
44 changes: 33 additions & 11 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,6 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
if (!op.bias().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op, "only support None bias");

Value c1 =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
Value groupEqual1 = rewriter.create<arith::CmpIOp>(
Expand Down Expand Up @@ -473,22 +470,47 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
castIndexToInt(weightW), strideIntValues[1]);

Value c0float = rewriter.create<arith::ConstantOp>(
loc,
FloatAttr::get(
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, F, Hout, Wout}, elementType);
Value initTensor0 =
rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0);

Value bias = adaptor.bias();
Value biasInitTensor;
if (bias.getType().isa<Torch::NoneType>()) {
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
if (biasType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
if (elementType != biasType.getElementType())
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");

auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
SmallVector<AffineMap> indexingMaps = {
// bias is used to initialize the channels - dimension 1 of output
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
rewriter.getMultiDimIdentityMap(resultRank)};
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
biasInitTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
}

auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
Value conv2d =
rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, initTensor0.getType(), ValueRange{paddedInput, weight},
initTensor0, stridesAttr, dilationAttr)
loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
biasInitTensor, stridesAttr, dilationAttr)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
Expand Down

0 comments on commit 2414bdb

Please sign in to comment.