Skip to content

[mlir] Add isStatic* size check for ShapedTypes. NFCI. #147085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,12 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type);
/// Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type);

/// Checks wither the dim-th dimension of the given shaped type is dynamic.
/// Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);

/// Checks whether the dim-th dimension of the given shaped type is static.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim);

/// Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
intptr_t dim);
Expand All @@ -300,17 +303,25 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
/// in shaped types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);

/// Checks whether the given shaped type dimension value is statically-sized.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size);

/// Returns the value indicating a dynamic size in a shaped type. Prefer
/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
/// mlirShapedTypeIsDynamicSize and mlirShapedTypeIsStaticSize to direct
/// comparisons with this value.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void);

/// Checks whether the given value is used as a placeholder for dynamic strides
/// and offsets in shaped types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);

/// Checks whether the given dimension value of a stride or an offset is
/// statically-sized.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val);

/// Returns the value indicating a dynamic stride or offset in a shaped type.
/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with
/// this value.
/// Prefer mlirShapedTypeIsDynamicStrideOrOffset and
/// mlirShapedTypeIsStaticStrideOrOffset to direct comparisons with this value.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);

//===----------------------------------------------------------------------===//
Expand Down
23 changes: 20 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
This may change in the future, for example, to require types to provide
their size or alignment given a data layout. Please post an RFC before
adding this interface to additional types. Implementing this interface on
downstream types is discourged, until we specified the exact properties of
downstream types is discouraged, until we specified the exact properties of
a vector element type in more detail.
}];
}
Expand Down Expand Up @@ -221,7 +221,17 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {

/// Whether the given shape has any size that indicates a dynamic dimension.
static bool isDynamicShape(ArrayRef<int64_t> dSizes) {
return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); });
return llvm::any_of(dSizes, isDynamic);
}

/// Whether the given dimension size indicates a statically-sized dimension.
static constexpr bool isStatic(int64_t dValue) {
return dValue != kDynamic;
}

/// Whether the given shape has static dimensions only.
static bool isStaticShape(ArrayRef<int64_t> dSizes) {
return llvm::all_of(dSizes, isStatic);
}

/// Return the number of elements present in the given shape.
Expand Down Expand Up @@ -273,11 +283,18 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
}

/// Returns true if this dimension has a static size (for ranked types);
/// aborts for unranked types.
bool isStaticDim(unsigned idx) const {
assert(idx < getRank() && "invalid index for shaped type");
return ::mlir::ShapedType::isStatic($_type.getShape()[idx]);
}

/// Returns if this type has a static shape, i.e. if the type is ranked and
/// all dimensions have known size (>= 0).
bool hasStaticShape() const {
return $_type.hasRank() &&
!::mlir::ShapedType::isDynamicShape($_type.getShape());
::mlir::ShapedType::isStaticShape($_type.getShape());
}

/// Returns if this type has a static shape and the shape is equal to
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
nb::arg("dim"),
"Returns whether the dim-th dimension of the given shaped type is "
"dynamic.");
c.def(
"is_static_dim",
[](PyShapedType &self, intptr_t dim) -> bool {
self.requireHasRank();
return mlirShapedTypeIsStaticDim(self, dim);
},
nb::arg("dim"),
"Returns whether the dim-th dimension of the given shaped type is "
"static.");
c.def(
"get_dim_size",
[](PyShapedType &self, intptr_t dim) {
Expand All @@ -558,6 +567,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
nb::arg("dim_size"),
"Returns whether the given dimension size indicates a dynamic "
"dimension.");
c.def_static(
"is_static_size",
[](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
nb::arg("dim_size"),
"Returns whether the given dimension size indicates a static "
"dimension.");
c.def(
"is_dynamic_stride_or_offset",
[](PyShapedType &self, int64_t val) -> bool {
Expand All @@ -567,6 +582,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
nb::arg("dim_size"),
"Returns whether the given value is used as a placeholder for dynamic "
"strides and offsets in shaped types.");
c.def(
"is_static_stride_or_offset",
[](PyShapedType &self, int64_t val) -> bool {
self.requireHasRank();
return mlirShapedTypeIsStaticStrideOrOffset(val);
},
nb::arg("dim_size"),
"Returns whether the given shaped type stride or offset value is "
"statically-sized.");
c.def_prop_ro(
"shape",
[](PyShapedType &self) {
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
.isDynamicDim(static_cast<unsigned>(dim));
}

bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) {
return llvm::cast<ShapedType>(unwrap(type))
.isStaticDim(static_cast<unsigned>(dim));
}

int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
return llvm::cast<ShapedType>(unwrap(type))
.getDimSize(static_cast<unsigned>(dim));
Expand All @@ -343,10 +348,18 @@ bool mlirShapedTypeIsDynamicSize(int64_t size) {
return ShapedType::isDynamic(size);
}

bool mlirShapedTypeIsStaticSize(int64_t size) {
return ShapedType::isStatic(size);
}

bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
return ShapedType::isDynamic(val);
}

bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) {
return ShapedType::isStatic(val);
}

int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
return ShapedType::kDynamic;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(

// Extract all strides and offsets and verify they are static.
auto [strides, offset] = type.getStridesAndOffset();
assert(!ShapedType::isDynamic(offset) && "expected static offset");
assert(ShapedType::isStatic(offset) && "expected static offset");
assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
"expected static strides");

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
if (ShapedType::isDynamic(stride))
return false;

return !ShapedType::isDynamic(offset);
return ShapedType::isStatic(offset);
}

/// Convert a memref type to a bare pointer to the memref element type.
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset);
return ShapedType::isStatic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
Expand Down Expand Up @@ -1468,7 +1468,7 @@ struct MemRefReshapeOpLowering
Value stride = nullptr;
int64_t targetRank = targetMemRefType.getRank();
for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
if (!ShapedType::isDynamic(strides[i])) {
if (ShapedType::isStatic(strides[i])) {
// If the stride for this dimension is dynamic, then use the product
// of the sizes of the inner dimensions.
stride =
Expand Down Expand Up @@ -1722,7 +1722,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
Type indexType) const {
assert(idx < shape.size());
if (!ShapedType::isDynamic(shape[idx]))
if (ShapedType::isStatic(shape[idx]))
return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
// Count the number of dynamic dims in range [0, idx]
unsigned nDynamic =
Expand All @@ -1738,7 +1738,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
ArrayRef<int64_t> strides, Value nextSize,
Value runningStride, unsigned idx, Type indexType) const {
assert(idx < strides.size());
if (!ShapedType::isDynamic(strides[idx]))
if (ShapedType::isStatic(strides[idx]))
return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
if (nextSize)
return runningStride
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// dimension greater than 1 with a different value is undefined behavior.
for (auto operand : operands) {
auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
if (!ShapedType::isDynamic(size) && size > 1)
if (ShapedType::isStatic(size) && size > 1)
return {rewriter.getIndexAttr(size), operand};
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
return totalSize / totalSizeNoPlaceholder;
});

bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
bool resultIsStatic = ShapedType::isStaticShape(resultShape);

// A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
// shaped input from being reshaped into a statically shaped result. We may
Expand Down Expand Up @@ -305,7 +305,7 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
int64_t size = i.value();
size_t index = i.index();
sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
if (!ShapedType::isDynamic(sizes.back()))
if (ShapedType::isStatic(sizes.back()))
continue;

auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
failed(target.getStridesAndOffset(targetStrides, targetOffset)))
return false;
auto dynamicToStatic = [](int64_t a, int64_t b) {
return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
};
if (dynamicToStatic(sourceOffset, targetOffset))
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static bool hasFullyDynamicLayoutMap(MemRefType type) {
return false;
if (!llvm::all_of(strides, ShapedType::isDynamic))
return false;
if (!ShapedType::isDynamic(offset))
if (ShapedType::isStatic(offset))
return false;
return true;
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4564,7 +4564,7 @@ static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
SmallVector<OpFoldResult> mixedInnerTiles;
unsigned dynamicValIndex = 0;
for (int64_t staticTile : op.getStaticInnerTiles()) {
if (!ShapedType::isDynamic(staticTile))
if (ShapedType::isStatic(staticTile))
mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
else
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
Expand Down Expand Up @@ -4829,7 +4829,7 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

if (!constantTile) {
if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
if (ShapedType::isStatic(outputTileSizes[pos]) &&
(inputShape[pos] % outputTileSizes[pos] != 0))
return true;
} else if (inputShape[pos] % (*constantTile) != 0) {
Expand Down Expand Up @@ -4935,7 +4935,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
// use dispatchIndexOpFoldResults on the result, and rely on exact number of
// dynamic dims returned by that.
for (unsigned i = 0; i < resultDims.size(); ++i) {
if (!ShapedType::isDynamic(resultTypeShape[i]))
if (ShapedType::isStatic(resultTypeShape[i]))
continue;
resultDims[i] =
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2061,7 +2061,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
rewriter.setInsertionPoint(linalgTarget);
for (OpOperand &operand : linalgTarget->getOpOperands()) {
for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
if (!ShapedType::isDynamic(dim))
if (ShapedType::isStatic(dim))
continue;
options.setSizeToPadTo(operand.getOperandNumber(), i,
tensor::getMixedSize(rewriter,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
LinalgOp linalgOp) {
// TODO: Support 0-d vectors.
for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
// Create constant index op for static dimensions.
iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
Expand Down Expand Up @@ -1652,7 +1652,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] =
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
}

// If missing, initialize the write indices to 0.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
int64_t shapeSize = shape[r];
std::optional<int64_t> sizeCst = getConstantIntValue(size);
auto hasTileSizeOne = sizeCst == 1;
auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
((shapeSize % *sizeCst) == 0);
if (!hasTileSizeOne && !dividesEvenly) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
Expand Down
Loading
Loading