
Commit

Fix code style
penguin-wwy committed Apr 27, 2024
1 parent c2b81c6 commit 22bab6a
Showing 16 changed files with 30 additions and 47 deletions.
3 changes: 1 addition & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1717,8 +1717,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         int startSize = startsTy.getDimSize(0);
 
         auto endsTorchTy = cast<Torch::ValueTensorType>(ends.getType());
-        auto endsTy =
-            dyn_cast<RankedTensorType>(endsTorchTy.toBuiltinTensor());
+        auto endsTy = dyn_cast<RankedTensorType>(endsTorchTy.toBuiltinTensor());
         int endSize = endsTy.getDimSize(0);
         auto resultTy =
             dyn_cast<RankedTensorType>(resultTorchType.toBuiltinTensor());
7 changes: 3 additions & 4 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -592,10 +592,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
     if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
       Value kHtimeskW = rewriter.create<arith::MulIOp>(
           loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
-      divisor =
-          isa<Torch::NoneType>(op.getDivisorOverride().getType())
-              ? kHtimeskW
-              : adaptor.getDivisorOverride();
+      divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
+                    ? kHtimeskW
+                    : adaptor.getDivisorOverride();
     } else {
       divisor = kernelSizeIntValues[0];
     }
6 changes: 2 additions & 4 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -66,8 +66,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
         cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
     auto idxResultType =
         cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
-    RankedTensorType inputType =
-        cast<RankedTensorType>(input.getType());
+    RankedTensorType inputType = cast<RankedTensorType>(input.getType());
     Type idxElementType =
         getElementTypeOrSelf(typec->convertType(idxResultType));
     if (!isa<IntegerType>(idxElementType))
@@ -480,8 +479,7 @@ class ConvertReductionOp : public ConversionPattern {
 
     SmallVector<int64_t> dimList;
     int64_t dim;
-    bool isNoneOrEmptyDimList =
-        isa<Torch::NoneType>(op.getDim().getType());
+    bool isNoneOrEmptyDimList = isa<Torch::NoneType>(op.getDim().getType());
     if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
       // Fix negative dimensions, if any, before adding to the list.
       for (int64_t dim : dimList) {
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -86,8 +86,7 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<OpTy> {
     Value input = adaptor.getA();
     SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
     int64_t inputRank = inputSizes.size();
-    Type inputDtype =
-        cast<BaseTensorType>(op.getA().getType()).getDtype();
+    Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
 
     // The `input` tensor must contain exactly one element, i.e., either the
     // `input` is a zero rank tensor or all the dimensions of the `input` tensor
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -173,8 +173,7 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
     return nullptr;
   }
 
-  Type elementalType =
-      cast<BaseTensorType>(op.getSelf().getType()).getDtype();
+  Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
   if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
     return createLessThan(b, loc, elementalType, lhs, rhs);
   }
5 changes: 2 additions & 3 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -37,9 +37,8 @@ Value torch_to_linalg::getPaddedTensor(
     SmallVectorImpl<int64_t> &lowPaddingInts,
     SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
   Location loc = op->getLoc();
-  Type rankedTensorType =
-      tensor::PadOp::inferResultType(cast<RankedTensorType>(input.getType()),
-                                     lowPaddingInts, highPaddingInts);
+  Type rankedTensorType = tensor::PadOp::inferResultType(
+      cast<RankedTensorType>(input.getType()), lowPaddingInts, highPaddingInts);
   SmallVector<OpFoldResult> lowPaddings =
       getIndexIntsAsOpFoldResult(b, lowPaddingInts);
   SmallVector<OpFoldResult> highPaddings =
9 changes: 3 additions & 6 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -304,17 +304,15 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto inputType =
-        dyn_cast<RankedTensorType>(adaptor.getA().getType());
+    auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
     if (!inputType)
 
       op.emitError("only Tensor types supported in StableHLO");
     Location loc = op.getLoc();
     Value input = adaptor.getA();
     SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
     int64_t inputRank = inputSizes.size();
-    Type inputDtype =
-        cast<BaseTensorType>(op.getA().getType()).getDtype();
+    Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
 
     Value constantOne =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
@@ -1012,8 +1010,7 @@ LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
     AtenScalarImplicitOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  Type inputDtype =
-      cast<BaseTensorType>(op.getA().getType()).getDtype();
+  Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
   Type resultType =
       this->getTypeConverter()->convertType(op->getResult(0).getType());
   auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -132,8 +132,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                            SmallVector<Value> &strides) {
   Location loc = op.getLoc();
   auto input = adaptor.getSelf();
-  RankedTensorType inputType =
-      cast<RankedTensorType>(input.getType());
+  RankedTensorType inputType = cast<RankedTensorType>(input.getType());
 
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -464,8 +464,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
     auto biasTy = bias.getType();
 
     // StableHLO does not mandate that elementwise op tensors need to be ranked.
-    if (!isa<Torch::NoneType>(biasTy) &&
-        !isa<RankedTensorType>(biasTy))
+    if (!isa<Torch::NoneType>(biasTy) && !isa<RankedTensorType>(biasTy))
       return op.emitError("only ranked tensor types are supported in StableHLO "
                           "matmul for bias tensor");
 
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -168,8 +168,7 @@ class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto rankType =
-        dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
+    auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
     if (!rankType)
       return op.emitError("Only ranked tensor types are currently supported");
 
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -1557,8 +1557,7 @@ class ConvertAtenScaledDotProductAttentionOp
     key = collapseBatch(key);
     value = collapseBatch(value);
 
-    SmallVector<int64_t> outSizes(
-        cast<ShapedType>(query.getType()).getShape());
+    SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
     SmallVector<int64_t> valueSizes(
         cast<ShapedType>(value.getType()).getShape());
     outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
6 changes: 2 additions & 4 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1744,8 +1744,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
     auto biasTy = bias.getType();
 
     // TOSA does not mandate that elementwise op tensors need to be ranked.
-    if (!isa<Torch::NoneType>(biasTy) &&
-        !isa<TensorType>(biasTy))
+    if (!isa<Torch::NoneType>(biasTy) && !isa<TensorType>(biasTy))
       return rewriter.notifyMatchFailure(
           op, "Only tensor types supported in GEMM to TOSA for bias tensor");
 
@@ -4462,8 +4461,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
   Value transposePoolingInputToHwc(AtenOpT op,
                                    ConversionPatternRewriter &rewriter,
                                    Value input) const {
-    auto inputRank =
-        cast<RankedTensorType>(input.getType()).getRank();
+    auto inputRank = cast<RankedTensorType>(input.getType()).getRank();
 
     SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
     SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
@@ -207,7 +207,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
       [](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs,
          Location loc) -> Value {
         assert(inputs.size() == 1);
-        assert(inputs[0].getType().isa<BaseTensorType>());
+        assert(isa<BaseTensorType>(inputs[0].getType()));
         return copyTensorToType(builder, loc, type, inputs[0]);
       });
   patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
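Beyond the pure reflow hunks, the assert changes here and in the type-conversion hunks below migrate the remaining member-style casts to LLVM's free-function casting API. A minimal sketch of the two styles on a hypothetical mlir::Value named `val` (an illustration under that assumption, not code from this repository):

#include <cassert>

#include "mlir/IR/BuiltinTypes.h" // RankedTensorType
#include "mlir/IR/Value.h"        // mlir::Value

using namespace mlir;

// Deprecated member form, being phased out across MLIR:
//   assert(val.getType().isa<RankedTensorType>());
// Preferred free-function form, as used throughout this commit:
void checkValue(Value val) { // `val` is a hypothetical example value
  assert(isa<RankedTensorType>(val.getType()));
  // cast<> asserts on a type mismatch; dyn_cast<> instead returns a null
  // type, so its result must be checked before use:
  if (auto ranked = dyn_cast<RankedTensorType>(val.getType()))
    (void)ranked.getRank();
}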
12 changes: 6 additions & 6 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -785,8 +785,8 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
         rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
     Value slice = rewriter.create<AtenSliceTensorOp>(
         loc,
-        computeReductionType(rewriter, op,
-                             cast<BaseTensorType>(self.getType()), dim,
+        computeReductionType(rewriter, op, cast<BaseTensorType>(self.getType()),
+                             dim,
                              /*keepDim=*/true),
         op.getSelf(), dim, start, startPlusOne, /*step=*/one);
 
@@ -2596,8 +2596,9 @@ class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
     }
 
     Type listElemType =
-        cast<BaseTensorType>(op.getType()).getWithSizesAndDtype(
-            /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
+        cast<BaseTensorType>(op.getType())
+            .getWithSizesAndDtype(
+                /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
     Type listType = Torch::ListType::get(listElemType);
     Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
         op.getLoc(), listType, unsqueezedTensors);
@@ -5175,8 +5176,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
                 PatternRewriter &rewriter) const override {
     Value dtype = op.getDtype();
     if (dtype.getType().isa<Torch::NoneType>()) {
-      BaseTensorType tensorType =
-          cast<BaseTensorType>(op.getSelf().getType());
+      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
       if (!tensorType.hasDtype()) {
         return rewriter.notifyMatchFailure(
             op, "expected input tensor to have a dtype");
3 changes: 1 addition & 2 deletions lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -130,8 +130,7 @@ class RecomposeSelectFill_ : public OpRewritePattern<AtenFill_TensorOp> {
 
     // Create IndexPut_Op
     // Convert indexNum to indexTensor for the selectOp
-    BaseTensorType selectOutTy =
-        cast<BaseTensorType>(selectOp.getType());
+    BaseTensorType selectOutTy = cast<BaseTensorType>(selectOp.getType());
     SmallVector<int64_t> empty;
     auto dtype = getTypeForTorchType(selectOp.getContext(),
                                      selectOp.getIndex().getType());
6 changes: 3 additions & 3 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -44,7 +44,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
                                   Torch::ValueTensorType type,
                                   ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<TensorType>());
+    assert(isa<TensorType>(inputs[0].getType()));
     return builder.create<FromBuiltinTensorOp>(loc, type, inputs[0]);
   };
   typeConverter.addSourceMaterialization(sourceMaterialization);
@@ -64,13 +64,13 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
         if (!(type.getWidth() == 1 && type.isSignless()))
          return std::nullopt;
        assert(inputs.size() == 1);
-        assert(inputs[0].getType().isa<Torch::BoolType>());
+        assert(isa<Torch::BoolType>(inputs[0].getType()));
        return builder.create<ToI1Op>(loc, inputs[0]).getResult();
      });
  auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<IntegerType>());
+    assert(isa<IntegerType>(inputs[0].getType()));
    return builder.create<FromI1Op>(loc, inputs[0]);
  };
  typeConverter.addSourceMaterialization(sourceMaterialization);
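The remaining hunks are pure line reflows with no behavioral change. A cleanup commit like this is typically produced by running clang-format over the lines touched by the parent change, for example with `git clang-format c2b81c6` (a plausible invocation, assumed rather than taken from the commit itself), and committing the result.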
