Skip to content

Commit

Permalink
[Torch] Eliminate getWithLeastStaticInformation in DecomposeAtenLinspaceOp and DecomposeAtenFakeQuantizePerTensorAffineOp (#3539)
Browse files · Browse the repository at this point in the history

as title
  • Loading branch information
Xinyu Yang authored Jul 15, 2024
1 parent fe9db78 commit e5d1677
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
42 changes: 28 additions & 14 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7592,7 +7592,6 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
Location loc = op.getLoc();
MLIRContext *context = getContext();

auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value zero =
Expand All @@ -7602,13 +7601,25 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {

Value addStart;
int64_t steps;
auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true);
auto fp32Type = rewriter.getF32Type();
auto arangeIntType =
getTensorTypeFromShapeValues({op.getSteps()}, si64Type);
auto arangeFp32Type =
getTensorTypeFromShapeValues({op.getSteps()}, fp32Type);
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
// specifically handle steps == 1
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
op.getStart(), one);
loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none,
op.getLayout(), op.getDevice(), op.getPinMemory());
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
isa<Torch::FloatType>(op.getStart().getType())) {
addStart = rewriter.create<AtenAddScalarOp>(loc, arangeFp32Type, arange,
op.getStart(), one);
} else {
addStart = rewriter.create<AtenAddScalarOp>(loc, arangeIntType, arange,
op.getStart(), one);
}
} else {
// handle steps != 1 or dynamic steps
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
Expand All @@ -7617,8 +7628,8 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
// create arange: [0, ..., steps - 1]
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none,
op.getLayout(), op.getDevice(), op.getPinMemory());
// calculate (end - start) / (steps - 1)
Value sub;
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
Expand All @@ -7632,15 +7643,16 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
Value mulScalar =
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
op.getStart(), one);
rewriter.create<AtenMulScalarOp>(loc, arangeFp32Type, arange, div);
addStart = rewriter.create<AtenAddScalarOp>(
loc, arangeFp32Type, mulScalar, op.getStart(), one);
}
// to dtype
Value result;
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
loc, op.getType(), addStart, op.getDtype(),
/*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
} else {
Value f32Type = rewriter.create<ConstantIntOp>(
Expand Down Expand Up @@ -8557,7 +8569,6 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

// input/scale
Value divScale = rewriter.create<AtenDivScalarOp>(
Expand All @@ -8568,16 +8579,19 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp
Value addZeroPoint = rewriter.create<AtenAddScalarOp>(
loc, op.getType(), round, op.getZeroPoint(), one);
// max(quant_min, std::nearby_int(input/scale) + zero_point)
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto tensorIntType =
ValueTensorType::get(context, ArrayRef<int64_t>{1}, si64Type);
Value max = rewriter.create<AtenMaximumOp>(
loc, op.getType(), addZeroPoint,
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMin(),
rewriter.create<AtenTensorIntOp>(loc, tensorIntType, op.getQuantMin(),
/*dtype=*/none,
/*device=*/none,
/*requires_grad=*/falseVal));
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
Value min = rewriter.create<AtenMinimumOp>(
loc, op.getType(), max,
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMax(),
rewriter.create<AtenTensorIntOp>(loc, tensorIntType, op.getQuantMax(),
/*dtype=*/none, /*device=*/none,
/*requires_grad=*/falseVal));
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
Expand Down
7 changes: 0 additions & 7 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,6 @@
"ElementwiseRreluTrainStaticModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"FloatImplicitModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
Expand Down Expand Up @@ -597,9 +593,6 @@
"ElementwiseToDtypeI64ToUI8Module_basic",
"EmptyModule_uint8",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
Expand Down

0 comments on commit e5d1677

Please sign in to comment.