[torch] Basic support for per-channel quantized graphs (#3623)
This patch adds basic support for lowering graphs with per-channel
quantization. Per-channel quantized ops have to be excluded from
`FuseQuantizedOps` for now but can be used in QDQ quantized form.

With this patch, we can import and execute (on the linalg backend)
graphs whose per-channel quantization was applied with the "new"
PyTorch 2.0 Export Quantization flow.
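
For context, the sketch below shows one way such a per-channel quantized graph can be produced with PT2 Export Quantization and handed to torch-mlir. It is a minimal sketch, not part of this patch: the capture entry point, the `XNNPACKQuantizer` configuration, and the `torch_mlir.fx.export_and_import` helper (including its `output_type` value) have shifted between releases, so treat those names as assumptions.

# Minimal sketch (not part of this patch): produce a per-channel quantized
# QDQ graph with PT2 Export Quantization and import it with torch-mlir.
# The capture entry point and the torch_mlir.fx helper have moved between
# releases, so the exact names below are assumptions.
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class SmallConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


example_inputs = (torch.randn(1, 3, 32, 32),)
# Older releases use torch._export.capture_pre_autograd_graph; newer ones use
# torch.export.export_for_training(...).module().
captured = torch._export.capture_pre_autograd_graph(SmallConv(), example_inputs)

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True)  # per-channel weights
)
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration run
quantized = convert_pt2e(prepared)  # emits quantized_decomposed.* QDQ ops

# Import to the torch dialect and lower via the linalg backend (helper name
# and output_type value are assumptions about the torch-mlir Python API).
from torch_mlir import fx

module = fx.export_and_import(quantized, *example_inputs,
                              output_type="linalg-on-tensors")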
ubfx authored Aug 10, 2024
1 parent 44266ab commit 0314188
Showing 6 changed files with 215 additions and 35 deletions.
4 changes: 4 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -2350,6 +2350,10 @@ class ConvertDequantizePerChannel
} else if (zeropointDTy.isSignedInteger(8)) {
zeropoint =
b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
} else if (zeropointDTy.isInteger(64)) {
zeropoint =
b.create<arith::TruncIOp>(loc, b.getI32Type(), zeropoint);
op->emitWarning() << "truncated zero point from 64 to 32 bit";
}

Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
77 changes: 48 additions & 29 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -44,6 +44,11 @@ bool isQCommutingOp(mlir::Operation *op) {
op);
}

struct QuantizedChain {
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd, scale, zeroPoint;
};

// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
@@ -58,10 +63,8 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {

LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {

mlir::Location loc = op.getLoc();
llvm::SmallVector<Value> operands(op->getOperands());
bool dequanted = false;

// Prevent fusion for 1d convolution ops and just do it as an f32 conv since
// there isn't a linalg named op for quantized 1-d convolution yet.
@@ -72,51 +75,70 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
return rewriter.notifyMatchFailure(
op, "1-d quantized convolution is not supported");

SmallVector<QuantizedChain, 2> operandChains;
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd, scale, zeroPoint;
QuantizedChain chain;
for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
if (!currOp)
break;
// Case 1 : currOp is a q commuting op (continue loop)
if (isQCommutingOp(currOp)) {
commutingOpStack.push(currOp);
chain.commutingOpStack.push(currOp);
// set operand to currOp for next k-iteration
operand = currOp->getOperand(0);
continue;
}
// Case 2 : currOp is a dequant op (end loop)
if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
dequantOpd = currOp->getOperand(0);
chain.dequantOpd = currOp->getOperand(0);
// Bail out if any operand is per-channel quantized, which would
// require more complex fusion logic.
if (llvm::isa<Aten_MakePerChannelQuantizedTensorOp>(
chain.dequantOpd.getDefiningOp()))
break;

auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
MPTQTOpd = MPTQTOp.getOperand(0);
scale = MPTQTOp.getOperand(1);
zeroPoint = MPTQTOp.getOperand(2);
chain.dequantOpd
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
chain.MPTQTOpd = MPTQTOp.getOperand(0);
chain.scale = MPTQTOp.getOperand(1);
chain.zeroPoint = MPTQTOp.getOperand(2);
}
// either a dequant was found or chain broken, so break loop
break;
}

// move to next operand if this trace was unsuccessful
if (!MPTQTOpd)
continue;
// if tracing this operand was successful, add it to operandChains.
if (chain.MPTQTOpd)
operandChains.push_back(std::move(chain));
}

// a successful trace occured, so set dequant to true
dequanted = true;
// Continuing the rewriting with only some of the operandsToQuantize traced
// successfully is possible but leads to "half-quantized" ops which are
// expected to cause problems in later lowering steps. We opt out of
// treating these cases for now.
if (operandChains.size() !=
std::size(QuantInfo<SrcOp>::operandsToQuantize)) {
if (!operandChains.empty())
op.emitWarning("Partially traced quantized operands. This op will "
"remain in QDQ form.");
return rewriter.notifyMatchFailure(
op, "did not find a complete quantized chain for all operands");
}

for (auto &&[i, chain] : llvm::enumerate(operandChains)) {
// rewrite stack
Value oldOpd = MPTQTOpd;
Value oldOpd = chain.MPTQTOpd;
Type intDType =
cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
while (!commutingOpStack.empty()) {
cast<ValueTensorType>(chain.MPTQTOpd.getType()).getOptionalDtype();
while (!chain.commutingOpStack.empty()) {
// get front of the commuting op stack and replace its first operand
// with oldOpd
auto currOp = commutingOpStack.top();
commutingOpStack.pop();
auto currOp = chain.commutingOpStack.top();
chain.commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd;
// pad ops aren't quite commuting, so we include some extra logic to
@@ -125,14 +147,15 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
Value floatPadValue = currOperands.back();
Value quantPadValue;
if (isa<Torch::NoneType>(floatPadValue.getType()))
quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
quantPadValue =
rewriter.create<AtenFloatScalarOp>(loc, chain.zeroPoint);
else {
floatPadValue =
rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
quantPadValue = rewriter.create<Torch::AtenDivFloatOp>(
loc, floatPadValue, scale);
loc, floatPadValue, chain.scale);
quantPadValue = rewriter.create<Torch::AtenAddFloatIntOp>(
loc, quantPadValue, zeroPoint);
loc, quantPadValue, chain.zeroPoint);
}
// clamp pad value to qint range
if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
@@ -175,19 +198,15 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
// stack is empty, so oldOpd is now the corrected version of the
// SrcOp's original operand
// convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
auto MPTQTOperands = chain.dequantOpd.getDefiningOp()->getOperands();
auto qTorchType =
cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
cast<ValueTensorType>(chain.dequantOpd.getType()).getOptionalDtype();
auto newMPTQTType = rewriter.getType<ValueTensorType>(
cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
}

if (!dequanted) {
return rewriter.notifyMatchFailure(op, "No dequantizations found.");
}

rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
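
Because the pattern now bails out on per-channel operands, such convolutions stay in explicit quantize/dequantize (QDQ) form instead of being fused. A minimal Python sketch of that QDQ form, mirroring the Conv2dQInt8PerChannelModuleBase test added below (shapes and quantization parameters are illustrative only):

# QDQ form for a per-channel quantized conv: dequantize explicitly, then run
# the float op. Mirrors Conv2dQInt8PerChannelModuleBase further down; the
# concrete shapes and quantization parameters here are illustrative.
import torch

inputVec = torch.randint(-128, 127, (2, 3, 12, 12), dtype=torch.int8)
weight = torch.randint(-128, 127, (8, 3, 3, 3), dtype=torch.int8)
scales = torch.rand(8)                         # one scale per output channel
zeropoints = torch.zeros(8, dtype=torch.int8)  # one zero point per output channel

x = torch.dequantize(torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7))
w = torch.dequantize(
    torch._make_per_channel_quantized_tensor(weight, scales, zeropoints, axis=0))
y = torch.ops.aten.conv2d(x, w, stride=[1, 1], padding=[0, 0], dilation=[1, 1])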
37 changes: 31 additions & 6 deletions lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp
@@ -59,10 +59,11 @@ class MatchQuantizeOperator : public OpRewritePattern<OperatorOp> {
return success();
}

if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
auto clamp = rewriter.create<AtenClampOp>(
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
op.getOperand(3), op.getOperand(4));
auto prepareDequantize = [&](Value quantMin, Value quantMax, Value &clamp,
Type &qTy) {
clamp =
rewriter.create<AtenClampOp>(op.getLoc(), op.getOperand(0).getType(),
op.getOperand(0), quantMin, quantMax);

auto clampTy = cast<Torch::ValueTensorType>(clamp.getType());
if (!clampTy.hasDtype())
@@ -75,15 +76,39 @@ class MatchQuantizeOperator : public OpRewritePattern<OperatorOp> {
return rewriter.notifyMatchFailure(op,
"dequantization has unknown qtype");

Type qTy = Torch::ValueTensorType::get(
op.getContext(), clampTy.getOptionalSizes(), qetype);
qTy = Torch::ValueTensorType::get(op.getContext(),
clampTy.getOptionalSizes(), qetype);
return success();
};

if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
Value clamp;
Type qTy;
if (failed(prepareDequantize(op.getOperand(3), op.getOperand(4), clamp,
qTy)))
return failure();

auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2));
rewriter.replaceOpWithNewOp<AtenDequantizeTensorOp>(
op, op.getResultTypes(), quant);
return success();
}

if (op.getName() == "torch.quantized_decomposed.dequantize_per_channel") {
Value clamp;
Type qTy;
if (failed(prepareDequantize(op.getOperand(4), op.getOperand(5), clamp,
qTy)))
return failure();
auto quant = rewriter.create<Aten_MakePerChannelQuantizedTensorOp>(
op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2),
op.getOperand(3));
rewriter.replaceOpWithNewOp<AtenDequantizeSelfOp>(op, op.getResultTypes(),
quant);
return success();
}

return failure();
}
};
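
For reference, the operand indices matched above follow the schemas of the decomposed quantization ops that PT2 Export Quantization emits: dequantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype) and dequantize_per_channel(input, scales, zero_points, axis, quant_min, quant_max, dtype). A hedged Python illustration follows (values are made up, and the private registration import is an assumption); it also shows the int64 zero points that motivate the TruncIOp added in Uncategorized.cpp above.

# Illustration of the decomposed dequantize schemas matched above. Per-tensor
# carries quant_min/quant_max at operand positions 3/4; per-channel shifts them
# to 4/5 because of the extra axis operand. Values are made up; the private
# registration module path is an assumption.
import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops

x = torch.randint(-128, 127, (4, 3, 8, 8), dtype=torch.int8)
scales = torch.rand(4)
zero_points = torch.zeros(4, dtype=torch.int64)  # int64, hence the TruncIOp above

per_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x, 0.05, 3, -128, 127, torch.int8)
per_channel = torch.ops.quantized_decomposed.dequantize_per_channel(
    x, scales, zero_points, 0, -128, 127, torch.int8)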
21 changes: 21 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -280,6 +280,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
@@ -382,6 +385,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
@@ -550,6 +556,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
@@ -2224,6 +2233,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"ConvTranspose2DQInt8_basic",
}

@@ -2374,6 +2386,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingModule_basic",
"Conv3dModule_basic",
@@ -2953,6 +2968,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Conv3dModule_basic",
@@ -3748,6 +3766,9 @@
"Conv2dQInt8Module_depthwise",
"Conv2dQInt8Module_grouped",
"Conv2dQInt8Module_not_depthwise",
"Conv2dQInt8PerChannelModule_basic",
"Conv2dQInt8PerChannelModule_depthwise",
"Conv2dQInt8PerChannelModule_grouped",
"Conv2dWithPaddingDilationStrideModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
90 changes: 90 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
@@ -1309,6 +1309,96 @@ def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
)


class Conv2dQInt8PerChannelModuleBase(torch.nn.Module):
def __init__(self, groups=1):
self.groups = groups
super().__init__()

def _forward(self, inputVec, weight, scales, zeropoints, bias):
inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
inputVec = torch.dequantize(inputVec)

weight = torch._make_per_channel_quantized_tensor(
weight, scales, zeropoints, axis=0
)
weight = torch.dequantize(weight)

bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
bias = torch.dequantize(bias)

return torch.ops.aten.conv2d(
inputVec,
weight,
bias=bias,
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
groups=self.groups,
)


class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase):
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
([-1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, inputVec, weight, scales, zeropoints, bias):
return self._forward(inputVec, weight, scales, zeropoints, bias)


class Conv2dQInt8PerChannelModuleStatic(Conv2dQInt8PerChannelModuleBase):
@export
@annotate_args(
[
None,
([2, 3, 12, 12], torch.int8, True),
([3, 1, 5, 3], torch.int8, True),
([3], torch.float, True),
([3], torch.int8, True),
([3], torch.float, True),
]
)
def forward(self, inputVec, weight, scales, zeropoints, bias):
return self._forward(inputVec, weight, scales, zeropoints, bias)


@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleDyn())
def Conv2dQInt8PerChannelModule_basic(module, tu: TestUtils):
inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
scales = tu.rand(3)
zeropoints = tu.rand(3).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, scales, zeropoints, bias)


@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleDyn(groups=2))
def Conv2dQInt8PerChannelModule_grouped(module, tu: TestUtils):
inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
scales = tu.rand(6)
zeropoints = tu.rand(6).to(torch.int8)
bias = torch.rand(6)
module.forward(inputVec, weight, scales, zeropoints, bias)


@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleStatic(groups=3))
def Conv2dQInt8PerChannelModule_depthwise(module, tu: TestUtils):
inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
scales = tu.rand(3)
zeropoints = tu.rand(3).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, scales, zeropoints, bias)


# torchvision.deform_conv2d

import torchvision