Bump llvm to llvm/llvm-project@5d6d982 #3994

Merged: 7 commits, Jan 31, 2025
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 4066 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 52 files
+1 −1 BUILD.bazel
+2 −2 WORKSPACE.bazel
+1 −1 build_tools/llvm_version.txt
+17 −1 docs/awesome.md
+3 −16 stablehlo/conversions/linalg/transforms/Rewriters.h
+39 −38 stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+23 −0 stablehlo/dialect/AssemblyFormat.cpp
+59 −0 stablehlo/dialect/AssemblyFormat.h
+17 −0 stablehlo/dialect/Base.cpp
+3 −0 stablehlo/dialect/Base.h
+17 −0 stablehlo/dialect/Base.td
+1 −1 stablehlo/dialect/CMakeLists.txt
+15 −0 stablehlo/dialect/StablehloAttrs.td
+79 −7 stablehlo/dialect/StablehloBytecode.cpp
+23 −0 stablehlo/dialect/StablehloEnums.td
+38 −0 stablehlo/dialect/StablehloOps.cpp
+19 −2 stablehlo/dialect/StablehloOps.td
+29 −4 stablehlo/dialect/TypeInference.cpp
+9 −0 stablehlo/dialect/TypeInference.h
+3 −3 stablehlo/dialect/Version.cpp
+1 −1 stablehlo/dialect/Version.h
+24 −11 stablehlo/dialect/VhloAttrs.td
+74 −1 stablehlo/dialect/VhloBytecode.cpp
+1 −0 stablehlo/dialect/VhloDialect.td
+33 −1 stablehlo/dialect/VhloEnums.td
+9 −8 stablehlo/dialect/VhloOps.cpp
+8 −1 stablehlo/dialect/VhloOps.td
+67 −0 stablehlo/integrations/c/StablehloAttributes.cpp
+37 −0 stablehlo/integrations/c/StablehloAttributes.h
+44 −0 stablehlo/integrations/python/StablehloModule.cpp
+21 −0 stablehlo/integrations/python/tests/stablehlo.py
+6 −7 stablehlo/reference/Types.cpp
+40 −0 stablehlo/tests/ops_stablehlo.mlir
+63 −0 stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
+5 −0 stablehlo/tests/ops_stablehlo_roundtrip.mlir
+13 −0 stablehlo/tests/print_stablehlo.mlir
+11 −0 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+2,966 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc
+31 −1 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+26 −0 stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir
+24 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir
+22 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir
+1 −1 stablehlo/transforms/MapStablehloToVhlo.h
+3 −3 stablehlo/transforms/PassUtils.h
+5 −0 stablehlo/transforms/Passes.h
+20 −2 stablehlo/transforms/StablehloAggressiveSimplification.cpp
+6 −3 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td
+24 −0 stablehlo/transforms/StablehloLegalizeToVhlo.cpp
+24 −0 stablehlo/transforms/VhloLegalizeToStablehlo.cpp
+53 −0 stablehlo/transforms/VhloToVersion.cpp
+16 −0 stablehlo/transforms/VhloToVersionPatterns.td
5 changes: 0 additions & 5 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -127,11 +127,6 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
                                 RankedTensorType weightTy,
                                 RankedTensorType outputTy, TypeAttr &accType);
 
-// Temporary function to get TOSA const shape
-// TODO: Remove this function when getTosaConstShape is available in
-// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
-Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
-                        llvm::ArrayRef<int64_t> shape);
 } // namespace tosa
 } // namespace mlir
 
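The deleted declaration is the temporary copy that the TODO above pointed at: with this LLVM bump, `getTosaConstShape` is picked up from upstream `mlir/Dialect/Tosa/Utils/ConversionUtils.h` instead (the new include lands in TosaLegalizeUtils.cpp further down). A minimal sketch of a call site against the upstream helper, assuming it keeps the signature of the removed copy:

```cpp
// Sketch only: builds a tosa.const_shape value through the upstream helper
// that replaces the temporary copy deleted above. Assumes the bumped LLVM
// exports getTosaConstShape from ConversionUtils.h with this signature.
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

static mlir::Value makeConstShape(mlir::PatternRewriter &rewriter,
                                  mlir::Location loc) {
  // Yields a !tosa.shape<3> value holding [2, 3, 4].
  return mlir::tosa::getTosaConstShape(rewriter, loc, {2, 3, 4});
}
```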
518 changes: 395 additions & 123 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp

Large diffs are not rendered by default.

59 changes: 49 additions & 10 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -351,7 +351,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
 
   // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) ->
   // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix.
-  auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
+  Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
       rewriter, op->getLoc(),
       GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
       indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
@@ -378,13 +378,18 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
   if (!flattenedCoeffValue)
     return std::nullopt;
 
+  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp,
+                                flattenedCoeffValue.value())
+          .failed())
+    return std::nullopt;
+
   // Multiply the coefficients by the coordinates
   // %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
   // tensor<3xi32>) -> tensor<8x3xi32>
   auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
       rewriter, op->getLoc(),
       GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
-      indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
+      indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
 
   // Sum up the products of the coefficients and coordinates
   // %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
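This hunk introduces the pattern the rest of the PR repeats: TOSA elementwise ops such as `tosa.mul` verify that their operands have equal rank, so the lower-ranked operand is first passed through the upstream `EqualizeRanks` helper, which prepends unit dimensions via `tosa.reshape` and reports failure when the ranks cannot be reconciled. A condensed sketch of the calling convention (the wrapper and its name are illustrative, not part of the patch):

```cpp
// Condensed sketch of the rank-equalization pattern threaded through these
// lowerings. EqualizeRanks takes both operands by reference and may replace
// the lower-ranked one with a tosa.reshape result, e.g. a tensor<3xi32>
// becomes tensor<1x3xi32> next to a tensor<8x3xi32>, so the elementwise op
// verifies. The wrapper name is illustrative only.
static std::optional<mlir::Value>
buildRankSafeMul(mlir::PatternRewriter &rewriter, mlir::Location loc,
                 mlir::Type resultType, mlir::Value lhs, mlir::Value rhs) {
  if (mlir::tosa::EqualizeRanks(rewriter, loc, lhs, rhs).failed())
    return std::nullopt; // ranks could not be reconciled
  return mlir::tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
             rewriter, loc, resultType, lhs, rhs, /*shift=*/0)
      .getResult();
}
```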
@@ -616,7 +621,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
   // [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]]
   // %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>)
   // -> tensor<3x2xi32>
-  auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
+  Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
       rewriter, op->getLoc(),
       GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
       indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
@@ -643,14 +648,19 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
   if (!flattenedCoeffValue)
     return std::nullopt;
 
+  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp,
+                                flattenedCoeffValue.value())
+          .failed())
+    return std::nullopt;
+
   // Multiply the coefficients by the coordinates.
   // [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
   // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
   // tensor<2xi32>) -> tensor<3x2xi32>
   auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
       rewriter, op->getLoc(),
       GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
-      indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
+      indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
 
   // Sum up the products of the coefficients and coordinates
   // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
@@ -734,10 +744,20 @@ std::optional<Value> convertReduceOpCommon(
     RankedTensorType reduce_type =
         RankedTensorType::get(shape_vec, reduce_element_type);
 
-    auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
-                                         val, axis_attr);
+    Value reduce_op;
+    if constexpr (std::is_same<T, tosa::ReduceMinOp>() ||
+                  std::is_same<T, tosa::ReduceMaxOp>()) {
+      // Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min
+      // and tosa.reduce_max
+      reduce_op = CreateOpAndInfer<T>(
+          rewriter, op->getLoc(), reduce_type, val, axis_attr,
+          /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
+    } else {
+      reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
+                                      val, axis_attr);
+    }
 
-    val = reduce_op.getResult();
+    val = reduce_op;
   }
 
   if (is_quantized) {
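The `if constexpr` split above exists because, after the bump, the `tosa.reduce_min`/`tosa.reduce_max` builders accept an extra NaN-propagation attribute that the other reductions (sum, prod, any, all) lack; since the branch is resolved at compile time, each instantiation of the template only sees a builder call that type-checks for its op. A self-contained illustration of the dispatch (it mirrors the diff; the builder signatures are inferred from it rather than quoted from the TOSA dialect):

```cpp
// Compile-time dispatch as used in the hunk above: the discarded branch is
// never instantiated, so reduction ops whose builders have no nan_mode
// parameter still compile. Requires <type_traits>.
template <typename T>
mlir::Value createReduceOp(mlir::PatternRewriter &rewriter, mlir::Location loc,
                           mlir::RankedTensorType type, mlir::Value val,
                           mlir::IntegerAttr axisAttr) {
  if constexpr (std::is_same_v<T, mlir::tosa::ReduceMinOp> ||
                std::is_same_v<T, mlir::tosa::ReduceMaxOp>) {
    // Min/max are the only reductions where NaN handling is observable.
    return mlir::tosa::CreateOpAndInfer<T>(
        rewriter, loc, type, val, axisAttr,
        /*nan_mode=*/rewriter.getStringAttr("PROPAGATE"));
  } else {
    return mlir::tosa::CreateOpAndInfer<T>(rewriter, loc, type, val, axisAttr);
  }
}
```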
@@ -973,6 +993,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
 
   if (!input_is_qtype) {
     Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
+
+    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), val.value(),
+                                  div_const)
+            .failed())
+      return std::nullopt;
+
     return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
                                          val.value(), div_const, 0)
         .getResult();
@@ -1021,6 +1047,11 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
     return std::nullopt;
   }
 
+  Value ordValRank0 = ordVal;
+  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal)
+          .failed())
+    return std::nullopt;
+
   if (fabs(ordLiteralFloat) < epsilon ||
       fabs(static_cast<double>(ordLiteralInt)) < epsilon) {
     op->emitOpError("unimplemented: L0 norm");
@@ -1049,9 +1080,17 @@
       rewriter, op, output_type, powVal, axes_elems, keep_dims);
   if (!result)
     return std::nullopt;
-  auto reciprocalVal = CreateOpAndInfer<tosa::ReciprocalOp>(
-                           rewriter, op->getLoc(), ordVal.getType(), ordVal)
-                           .getResult();
+
+  Value reciprocalVal =
+      CreateOpAndInfer<tosa::ReciprocalOp>(rewriter, op->getLoc(),
+                                           ordValRank0.getType(), ordValRank0)
+          .getResult();
+
+  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), result.value(),
+                                reciprocalVal)
+          .failed())
+    return std::nullopt;
+
   return CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(), output_type,
                                        result.value(), reciprocalVal)
       .getResult();
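One subtlety in the two hunks above: `EqualizeRanks` takes its arguments by reference, so the earlier `EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal)` call may have replaced `ordVal` with a rank-expanded reshape. That is why the rank-0 snapshot `ordValRank0` is taken first and feeds the scalar `tosa.reciprocal`, while the reciprocal is then rank-equalized against the reduction result before `tosa.pow`. A condensed sketch of that ordering (the wrapper and names are illustrative):

```cpp
// Condensed sketch of the ordering that matters in convertLinalgVectorNormOp:
// snapshot the rank-0 ord value before EqualizeRanks can rewrite it, compute
// 1/ord on the snapshot, then rank-equalize against the reduced result for
// tosa.pow. `reduced` stands in for the reduce-sum of |x|^ord computed in
// between; the wrapper itself is illustrative, not the patch's code.
static std::optional<mlir::Value>
buildVectorNormTail(mlir::PatternRewriter &rewriter, mlir::Location loc,
                    mlir::Type outTy, mlir::Value input, mlir::Value ord,
                    mlir::Value reduced) {
  mlir::Value ordRank0 = ord; // keep the scalar-shaped value
  if (mlir::tosa::EqualizeRanks(rewriter, loc, input, ord).failed())
    return std::nullopt; // ord may now be a rank-expanded tosa.reshape
  mlir::Value recip = mlir::tosa::CreateOpAndInfer<mlir::tosa::ReciprocalOp>(
                          rewriter, loc, ordRank0.getType(), ordRank0)
                          .getResult(); // 1/ord on the original rank-0 value
  if (mlir::tosa::EqualizeRanks(rewriter, loc, reduced, recip).failed())
    return std::nullopt;
  return mlir::tosa::CreateOpAndInfer<mlir::tosa::PowOp>(rewriter, loc, outTy,
                                                         reduced, recip)
      .getResult();
}
```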
52 changes: 23 additions & 29 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -8,7 +8,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"       // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
 
 namespace mlir {
@@ -301,31 +302,31 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
       (src.isF32() && dest.isInteger(8)) ||
       (src.isF32() && dest.isBF16()) ||
       (src.isF32() && dest.isF16()) ||
-      (src.isF32() && dest.isFloat8E4M3()) ||
-      (src.isF32() && dest.isFloat8E5M2()) ||
+      (src.isF32() && isa<Float8E4M3Type>(dest)) ||
+      (src.isF32() && isa<Float8E5M2Type>(dest)) ||
       // f16 -> *
       (src.isF16() && dest.isInteger(32)) ||
       (src.isF16() && dest.isInteger(16)) ||
      (src.isF16() && dest.isInteger(8)) ||
       (src.isF16() && dest.isBF16()) ||
       (src.isF16() && dest.isF32()) ||
-      (src.isF16() && dest.isFloat8E4M3()) ||
-      (src.isF16() && dest.isFloat8E5M2()) ||
+      (src.isF16() && isa<Float8E4M3Type>(dest)) ||
+      (src.isF16() && isa<Float8E5M2Type>(dest)) ||
       // bf16 -> *
       (src.isBF16() && dest.isInteger(32)) ||
       (src.isBF16() && dest.isInteger(16)) ||
       (src.isBF16() && dest.isInteger(8)) ||
       (src.isBF16() && dest.isF32()) ||
-      (src.isBF16() && dest.isFloat8E4M3()) ||
-      (src.isBF16() && dest.isFloat8E5M2()) ||
+      (src.isBF16() && isa<Float8E4M3Type>(dest)) ||
+      (src.isBF16() && isa<Float8E5M2Type>(dest)) ||
       // fp8e4m3 -> *
-      (src.isFloat8E4M3() && dest.isBF16()) ||
-      (src.isFloat8E4M3() && dest.isF32()) ||
-      (src.isFloat8E4M3() && dest.isF16()) ||
+      (isa<Float8E4M3Type>(src) && dest.isBF16()) ||
+      (isa<Float8E4M3Type>(src) && dest.isF32()) ||
+      (isa<Float8E4M3Type>(src) && dest.isF16()) ||
       // fp8e5m2 -> *
-      (src.isFloat8E5M2() && dest.isBF16()) ||
-      (src.isFloat8E5M2() && dest.isF32()) ||
-      (src.isFloat8E5M2() && dest.isF16())) {
+      (isa<Float8E5M2Type>(src) && dest.isBF16()) ||
+      (isa<Float8E5M2Type>(src) && dest.isF32()) ||
+      (isa<Float8E5M2Type>(src) && dest.isF16())) {
     return success();
   }
   // clang-format on
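The mechanical rewrites in this hunk track the bumped LLVM, in which the per-format convenience predicates such as `Type::isFloat8E4M3()` are no longer available and the check is spelled as `isa<>` against the concrete builtin float types. A minimal sketch of the new spelling (the helper is ours, purely illustrative):

```cpp
// Post-bump FP8 check: isa<> against the concrete builtin float types
// replaces the removed Type::isFloat8E4M3()/isFloat8E5M2() predicates.
// Helper name is ours, for illustration only.
#include "mlir/IR/BuiltinTypes.h"

static bool isAnyFloat8(mlir::Type t) {
  return mlir::isa<mlir::Float8E4M3Type, mlir::Float8E5M2Type>(t);
}
```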
@@ -393,6 +394,11 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
     auto zeroValue =
         tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();
 
+    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "Failed to equalize ranks among operands and result");
+
     auto boolType = srcType.clone(rewriter.getIntegerType(1));
     auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
         rewriter, op->getLoc(), boolType, zeroValue, src);
@@ -488,10 +494,10 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
   } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
              outputElemTy.isInteger(48)) {
     accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
-  } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
-              outputElemTy.isF16()) ||
-             (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
-              outputElemTy.isF16())) {
+  } else if ((isa<Float8E4M3Type>(inputElemTy) &&
+              isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
+             (isa<Float8E5M2Type>(inputElemTy) &&
+              isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
     accType = mlir::TypeAttr::get(rewriter.getF16Type());
   } else {
     accType = mlir::TypeAttr::get(outputElemTy);
@@ -500,17 +506,5 @@
   return success();
 }
 
-// Temporary function to get TOSA const shape
-// TODO: Remove this function when getTosaConstShape is available in
-// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
-Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
-                        llvm::ArrayRef<int64_t> shape) {
-  auto attr = rewriter.getIndexTensorAttr(shape);
-  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
-  mlir::Operation *mlir_op =
-      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
-  return mlir_op->getResult(0);
-}
-
 } // namespace tosa
 } // namespace mlir
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -652,13 +652,13 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getF32Type();
   if (isa<Float64Type>(inputType))
     return rewriter.getF64Type();
-  if (inputType.isFloat8E5M2())
+  if (isa<Float8E5M2Type>(inputType))
     return rewriter.getF32Type();
-  if (inputType.isFloat8E4M3FN())
+  if (isa<Float8E4M3FNType>(inputType))
     return rewriter.getF32Type();
-  if (inputType.isFloat8E5M2FNUZ())
+  if (isa<Float8E5M2FNUZType>(inputType))
     return rewriter.getF32Type();
-  if (inputType.isFloat8E4M3FNUZ())
+  if (isa<Float8E4M3FNUZType>(inputType))
     return rewriter.getF32Type();
   if (inputType.isInteger(8))
     // this is an intentional deviation from CUDA (which accumulates i8 to i64)