Skip to content

Commit c9694c6

Browse files
authored
Add a CMake option to enable TOSA. Default to ON. (llvm#4021)
Fixes llvm#4019. --------- Signed-off-by: Benoit Jacob <[email protected]>
1 parent ddc180f commit c9694c6

File tree

12 files changed

+62
-24
lines changed

12 files changed

+62
-24
lines changed

CMakeLists.txt

+10
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,17 @@ option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directo
3535
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON)
3636

3737
option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON)
38+
3839
if(TORCH_MLIR_ENABLE_REFBACKEND)
3940
add_definitions(-DTORCH_MLIR_ENABLE_REFBACKEND)
4041
endif()
4142

43+
set(TORCH_MLIR_TABLEGEN_FLAGS "")
44+
4245
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
4346
if(TORCH_MLIR_ENABLE_STABLEHLO)
4447
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
48+
list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_STABLEHLO")
4549
endif()
4650
# It is possible that both stablehlo and torch_mlir projects are used in some compiler project.
4751
# In this case, we don't want to use stablehlo that is downloaded by torch_mlir (in external/stablehlo)
@@ -50,6 +54,12 @@ endif()
5054
# stablehlo targets AND includes available (for example with `add_subdirectory` and `include_directories`).
5155
option(TORCH_MLIR_USE_EXTERNAL_STABLEHLO "Use stablehlo from top level project" OFF)
5256

57+
option(TORCH_MLIR_ENABLE_TOSA "Add TOSA support" ON)
58+
if(TORCH_MLIR_ENABLE_TOSA)
59+
add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
60+
list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_TOSA")
61+
endif()
62+
5363
option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)
5464

5565
# PyTorch native extension gate. If OFF, then no features which depend on
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
add_subdirectory(TorchOnnxToTorch)
22

33
set(LLVM_TARGET_DEFINITIONS Passes.td)
4-
if(TORCH_MLIR_ENABLE_STABLEHLO)
5-
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
6-
else()
7-
mlir_tablegen(Passes.h.inc -gen-pass-decls)
8-
endif()
4+
5+
6+
7+
mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS})
8+
99
add_public_tablegen_target(TorchMLIRConversionPassIncGen)
1010

1111
add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)

include/torch-mlir/Conversion/Passes.td

+2
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> {
114114
let constructor = "mlir::torch::createConvertTorchToTensorPass()";
115115
}
116116

117+
#ifdef TORCH_MLIR_ENABLE_TOSA
117118
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
118119
let summary = "Convert Torch ops to TOSA ops";
119120
let description = [{
@@ -122,6 +123,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
122123
}];
123124
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
124125
}
126+
#endif
125127

126128
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
127129
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
set(LLVM_TARGET_DEFINITIONS Passes.td)
2-
if(TORCH_MLIR_ENABLE_STABLEHLO)
3-
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
4-
else()
5-
mlir_tablegen(Passes.h.inc -gen-pass-decls)
6-
endif()
2+
3+
mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS})
4+
75
add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen)
86

97
add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc)

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@ namespace TorchConversion {
2626
/// linalg-on-tensors backend contract.
2727
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
2828

29+
// Do not register the TOSA options if the TOSA target is disabled
30+
#ifdef TORCH_MLIR_ENABLE_TOSA
2931
/// Creates a pipeline that lowers from the torch backend contract to the
3032
/// TOSA backend contract.
3133
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
3234

35+
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
36+
#endif // TORCH_MLIR_ENABLE_TOSA
37+
3338
// Do not register the stablehlo options if the stablehlo target is disabled
3439
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
3540
struct StablehloBackendPipelineOptions
@@ -57,7 +62,7 @@ createFinalizingBackendTypeConversionForStablehloPass();
5762

5863
std::unique_ptr<OperationPass<ModuleOp>>
5964
createVerifyStablehloBackendContractPass();
60-
#endif
65+
#endif // TORCH_MLIR_ENABLE_STABLEHLO
6166

6267
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
6368

@@ -77,8 +82,6 @@ createConvertCustomQuantOpPass();
7782
std::unique_ptr<OperationPass<ModuleOp>>
7883
createVerifyLinalgOnTensorsBackendContractPass();
7984

80-
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
81-
8285
} // namespace TorchConversion
8386

8487
/// Registers all Torch transformation passes.

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td

+2
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,12 @@ def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-
6161
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
6262
}
6363

64+
#ifdef TORCH_MLIR_ENABLE_TOSA
6465
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
6566
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
6667
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
6768
}
69+
#endif
6870

6971
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
7072
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {

lib/CMakeLists.txt

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ set(LinkedLibs
1414
MLIRSCFDialect
1515
MLIRTensorDialect
1616
MLIRTensorInferTypeOpInterfaceImpl
17-
MLIRTosaDialect
1817
MLIRSupport
1918

2019
# Dialects.
@@ -33,7 +32,11 @@ set(LinkedLibs
3332
)
3433

3534
if(TORCH_MLIR_ENABLE_STABLEHLO)
36-
list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses)
35+
list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses)
36+
endif()
37+
38+
if(TORCH_MLIR_ENABLE_TOSA)
39+
list(APPEND LinkedLibs MLIRTosaDialect)
3740
endif()
3841

3942
if(TORCH_MLIR_ENABLE_REFBACKEND)

lib/Conversion/CMakeLists.txt

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ add_subdirectory(TorchToArith)
33
add_subdirectory(TorchToLinalg)
44
add_subdirectory(TorchToSCF)
55
add_subdirectory(TorchToTensor)
6-
add_subdirectory(TorchToTosa)
6+
if(TORCH_MLIR_ENABLE_TOSA)
7+
add_subdirectory(TorchToTosa)
8+
endif()
79
if(TORCH_MLIR_ENABLE_STABLEHLO)
810
add_subdirectory(TorchToStablehlo)
911
endif()
@@ -16,13 +18,15 @@ set(linked_libs TorchMLIRTorchToArith
1618
TorchMLIRTorchToLinalg
1719
TorchMLIRTorchToSCF
1820
TorchMLIRTorchToTensor
19-
TorchMLIRTorchToTosa
2021
TorchMLIRTorchToTMTensor
2122
TorchMLIRTorchConversionToMLProgram
2223
TorchMLIRConversionUtils)
2324
if(TORCH_MLIR_ENABLE_STABLEHLO)
2425
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
2526
endif()
27+
if(TORCH_MLIR_ENABLE_TOSA)
28+
list(APPEND linked_libs TorchMLIRTorchToTosa)
29+
endif()
2630

2731
add_mlir_library(TorchMLIRConversionPasses
2832
Passes.cpp

lib/Conversion/Passes.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
2020
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
2121
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
22+
23+
#ifdef TORCH_MLIR_ENABLE_TOSA
2224
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
25+
#endif // TORCH_MLIR_ENABLE_TOSA
2326

2427
//===----------------------------------------------------------------------===//
2528
// Pass registration

lib/Dialect/TorchConversion/Transforms/Passes.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
1919
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
2020
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
21-
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
2221
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
22+
2323
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
2424
#include "stablehlo/transforms/Passes.h"
2525
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
2626
#endif
27-
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
27+
28+
#ifdef TORCH_MLIR_ENABLE_TOSA
29+
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
30+
using namespace mlir::tosa;
31+
#endif
2832

2933
using namespace mlir;
3034
using namespace mlir::torch;
31-
using namespace mlir::tosa;
3235

3336
//===----------------------------------------------------------------------===//
3437
// Pass registration
@@ -46,12 +49,13 @@ void mlir::torch::registerTorchConversionPasses() {
4649
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
4750
"contract.",
4851
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
49-
52+
#ifdef TORCH_MLIR_ENABLE_TOSA
5053
mlir::PassPipelineRegistration<>(
5154
"torch-backend-to-tosa-backend-pipeline",
5255
"Pipeline lowering torch backend contract to TOSA backend "
5356
"contract.",
5457
TorchConversion::createTorchBackendToTosaBackendPipeline);
58+
#endif
5559
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
5660
mlir::PassPipelineRegistration<
5761
TorchConversion::StablehloBackendPipelineOptions>(
@@ -107,6 +111,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
107111
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
108112
}
109113

114+
#ifdef TORCH_MLIR_ENABLE_TOSA
110115
void TorchConversion::createTorchBackendToTosaBackendPipeline(
111116
OpPassManager &pm) {
112117
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
@@ -130,6 +135,7 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
130135
// correct form.
131136
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
132137
}
138+
#endif
133139

134140
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
135141
void TorchConversion::createTorchBackendToStablehloBackendPipeline(

lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// Also available under a BSD-style license. See LICENSE.
77
//
88
//===----------------------------------------------------------------------===//
9-
9+
#ifdef TORCH_MLIR_ENABLE_TOSA
1010
#include "PassDetail.h"
1111

1212
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -63,3 +63,4 @@ std::unique_ptr<OperationPass<ModuleOp>>
6363
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() {
6464
return std::make_unique<VerifyTosaBackendContractPass>();
6565
}
66+
#endif // TORCH_MLIR_ENABLE_TOSA

lib/InitAll.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
2020
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2121
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
22-
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
2322
#include "mlir/IR/Dialect.h"
2423
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
2524
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
@@ -36,6 +35,10 @@
3635
#include "stablehlo/transforms/Passes.h"
3736
#endif
3837

38+
#ifdef TORCH_MLIR_ENABLE_TOSA
39+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
40+
#endif
41+
3942
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
4043
registry.insert<mlir::func::FuncDialect>();
4144
registry.insert<mlir::torch::Torch::TorchDialect>();
@@ -54,7 +57,10 @@ void mlir::torch::registerOptionalInputDialects(
5457
registry.insert<complex::ComplexDialect, linalg::LinalgDialect,
5558
memref::MemRefDialect, ml_program::MLProgramDialect,
5659
scf::SCFDialect, sparse_tensor::SparseTensorDialect,
57-
tensor::TensorDialect, tosa::TosaDialect>();
60+
tensor::TensorDialect>();
61+
#ifdef TORCH_MLIR_ENABLE_TOSA
62+
registry.insert<tosa::TosaDialect>();
63+
#endif
5864
}
5965

6066
void mlir::torch::registerAllPasses() {

0 commit comments

Comments
 (0)