diff --git a/backend/npu.py b/backend/npu.py index 11bba16e..fd90e9e4 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -346,7 +346,7 @@ def ttsharedir_to_dicp(mod, metadata, opt, *, named_ops=False): # content = content.replace('func.func @_silu_and_mul_kernel(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32)', # 'func.func @_silu_and_mul_kernel(%arg1000: memref, %arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) attributes {WorkspaceArgIdx = 0 : i64, global_kernel = "local", mix_mode = "aiv"}') # 将"*xf16"替换成"?xf16" - content = content.replace("*xf16", "?xf16") + content = content.replace("*xf", "?xf") print(f"zmz debug: after replace content: {content}") # 将context 写回去 with open(dst_path, 'w') as f: diff --git a/compiler/include/dicp/Conversion/LinalgToNPU/ConversionPatterns.hpp b/compiler/include/dicp/Conversion/LinalgToNPU/ConversionPatterns.hpp index 75107975..e2d12e24 100644 --- a/compiler/include/dicp/Conversion/LinalgToNPU/ConversionPatterns.hpp +++ b/compiler/include/dicp/Conversion/LinalgToNPU/ConversionPatterns.hpp @@ -11,6 +11,15 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" +#include "mlir/IR/BuiltinAttributes.h" +// #include "mlir/Dialect/Arith/IR/ArithAttributes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +// #include "mlir/Dialect/Linalg/IR/LinalgAttributes.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include #include @@ -19,6 +28,7 @@ using namespace mlir; using namespace dicp; +using namespace mlir::utils; namespace { @@ -331,4 +341,180 @@ struct ConvertLinalgGenericToArith }; -} // namespace \ No newline at end of file +// ... 已有代码 ... + + +struct ConvertLinalgGenericToBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + std::cout << "[ConvertLinalgGenericToBroadcast] Starting matchAndRewrite for LinalgGenericOp at location: "; + genericOp.getLoc().print(llvm::outs()); + std::cout << "\n"; + + // === Step 1: 检查是否是一个合法的 broadcast 模式 === + if (!isBroadcastPattern(genericOp)) { + std::cout << "[INFO] Not a broadcast pattern\n"; + return failure(); + } + std::cout << "[INFO] Detected broadcast pattern\n"; + + // === Step 2: 获取输入输出 Tensor === + Value input = genericOp.getDpsInputOperand(0)->get(); + Value output = genericOp.getDpsInitOperand(0)->get(); + + // 提取 broadcastDims 属性 + auto attr = genericOp->getAttr("broadcastDims"); + if (!attr) { + std::cout << "[ERROR] Missing 'broadcastDims' attribute\n"; + return failure(); + } + + std::cout << "[DEBUG] Raw broadcastDims attribute value: "; + attr.print(llvm::outs()); + std::cout << "\n"; + SmallVector broadcastDims; + if (auto denseAttr = dyn_cast(attr)) { + // broadcastDims = denseAttr.asArrayRef().vec(); + broadcastDims.assign(denseAttr.asArrayRef().begin(), denseAttr.asArrayRef().end()); + } else if (auto arrayAttr = dyn_cast(attr)) { + for (auto element : arrayAttr) { + if (auto intAttr = dyn_cast(element)) { + broadcastDims.push_back(intAttr.getInt()); + } else { + std::cout << "[ERROR] Invalid element in 'broadcastDims' array\n"; + return failure(); + } + } + } else { + std::cout << "[ERROR] Invalid 'broadcastDims' attribute type\n"; + return failure(); + } + + std::cout << "[INFO] Detected broadcastDims = ["; + for (int64_t d : broadcastDims) { + std::cout << d << " "; + } + std::cout << "]\n"; + + // 创建 linalg.broadcast 操作 + rewriter.setInsertionPoint(genericOp); + auto broadcastOp = rewriter.create( + genericOp.getLoc(), + input, + output, + broadcastDims + ); + + // 替换原操作 + // 打印操作 + std::cout << "[DEBUG] Replacing linalg.generic with linalg.broadcast\n"; + std::cout << "[DEBUG] Before replacement:\n"; + genericOp.print(llvm::outs()); + std::cout << "\n"; + // 打印替换后的操作 + std::cout << "[DEBUG] After replacement:\n"; + broadcastOp.print(llvm::outs()); + std::cout << "\n"; + rewriter.replaceOp(genericOp, broadcastOp.getResult()); + std::cout << "[INFO] Replaced linalg.generic with linalg.broadcast\n"; + return success(); + } + +private: + // 判断是否为 broadcast 模式 + bool isBroadcastPattern(linalg::GenericOp op) const { + std::cout << "[DEBUG] Checking broadcast pattern for linalg.generic at location: "; + op.getLoc().print(llvm::outs()); + std::cout << "\n"; + + // 1. 检查迭代器类型 + SmallVector iterTypes; + if (failed(getIteratorTypeNames(op, iterTypes))) { + std::cout << "[DEBUG] Failed to get iterator type names. Not a broadcast pattern.\n"; + return false; + } + if (!llvm::all_of(iterTypes, [](StringRef type) { return type == "parallel"; })) { + std::cout << "[DEBUG] Iterator types are not all 'parallel'.\n"; + return false; + } + std::cout << "[DEBUG] All iterator types are 'parallel'.\n"; + + // 2. 检查输入输出数量 + if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1) { + std::cout << "[DEBUG] Expected 1 input and 1 output.\n"; + return false; + } + std::cout << "[DEBUG] Number of inputs and outputs matched.\n"; + + // 检查 block 结构 + // Region ®ion = op->getRegion(0); + if (op->getNumRegions() == 0) { + std::cout << "[ERROR] Operation has no regions.\n"; + return false; + } + + Region *region = &op->getRegion(0); + if (!region || !region->hasOneBlock()) { + std::cout << "[ERROR] Region not valid or does not have one block.\n"; + return false; + } + std::cout << "[DEBUG] Region has exactly one block.\n"; + + Block &block = region->front(); + if (block.empty() || !isa(block.back())) { + std::cout << "[DEBUG] Block does not end with linalg.yield.\n"; + return false; + } + std::cout << "[DEBUG] Block is not empty and ends with linalg.yield.\n"; + + Operation *innerOp = block.getTerminator()->getPrevNode(); + if (innerOp && !isa(innerOp)) { + std::cout << "[DEBUG] Expected only linalg.yield in block.\n"; + return false; + } + std::cout << "[DEBUG] No inner operation before yield.\n"; + + auto yieldOp = cast(block.getTerminator()); + if (yieldOp->getNumOperands() != 1 || yieldOp->getOperand(0) != block.getArgument(0)) { + std::cout << "[DEBUG] Yield operand mismatch.\n"; + return false; + } + std::cout << "[DEBUG] Yield operand check passed.\n"; + + std::cout << "[DEBUG] Inner operation and yield check passed.\n"; + + // 4. 检查是否有 broadcastDims 属性 + if (!op->hasAttr("broadcastDims")) { + std::cout << "[DEBUG] Missing 'broadcastDims' attribute.\n"; + return false; + } + + std::cout << "[DEBUG] 'broadcastDims' attribute found. This is a broadcast pattern.\n"; + return true; + } + + // 获取 Iterator 类型名称 + LogicalResult getIteratorTypeNames(linalg::GenericOp op, + SmallVectorImpl &types) const { + auto iteratorAttrs = op.getIteratorTypes().getValue(); + for (Attribute attr : iteratorAttrs) { + if (auto iterTypeAttr = dyn_cast(attr)) { + types.push_back(mlir::utils::stringifyIteratorType(iterTypeAttr.getValue())); + } else if (auto strAttr = dyn_cast(attr)) { + types.push_back(strAttr.getValue()); + } else { + std::cout << "[ERROR] Unsupported iterator type attribute.\n"; + return failure(); + } + } + return success(); + } +}; + +// ... 已有代码 ... + + + +} // namespace diff --git a/compiler/include/dicp/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.hpp b/compiler/include/dicp/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.hpp new file mode 100644 index 00000000..54a0a88c --- /dev/null +++ b/compiler/include/dicp/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.hpp @@ -0,0 +1,469 @@ +#ifndef CONVERT_RANKED_TO_UNRANKED_PASS_H +#define CONVERT_RANKED_TO_UNRANKED_PASS_H + +#include "mlir/Pass/Pass.h" +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallVector.h" +#include +// #include "llvm/ADT/Optional.h" // 添加缺失的头文件 + +#define DICPRTU_DEBUG_TYPE "convert-ranked-to-unranked" + +using namespace mlir; +using namespace mlir::func; + +namespace mlir { +namespace dicp { +namespace npu { + +inline bool isLegalForUnrankedTypeConversion(Operation *op, TypeConverter &typeConverter) { + llvm::outs() << "Checking legality of operation: "; + op->getName().print(llvm::outs()); + llvm::outs() << "\n"; + for (Type type : op->getOperandTypes()) { + if (auto memrefTy = dyn_cast(type)) { + if (memrefTy.hasStaticShape()) { + llvm::outs() << "Operand type has static shape, considered legal: "; + type.print(llvm::outs()); + llvm::outs() << "\n"; + continue; + } + } + if (!typeConverter.isLegal(type)) { + llvm::outs() << "Operand type is illegal: "; + type.print(llvm::outs()); + llvm::outs() << "\n"; + return false; + } + } + for (Type type : op->getResultTypes()) { + if (auto memrefTy = dyn_cast(type)) { + if (memrefTy.hasStaticShape()) { + llvm::outs() << "Result type has static shape, considered legal: "; + type.print(llvm::outs()); + llvm::outs() << "\n"; + continue; + } + } + if (!typeConverter.isLegal(type)) { + llvm::outs() << "Result type is illegal: "; + type.print(llvm::outs()); + llvm::outs() << "\n"; + return false; + } + } + llvm::outs() << "Operation is legal: "; + op->getName().print(llvm::outs()); + llvm::outs() << "\n"; + return true; +} + +inline Value convertValueIfNeeded(OpBuilder &builder, Location loc, Value value, const TypeConverter &typeConverter) { + Type originalType = value.getType(); + llvm::outs() << "Converting value with original type: "; + originalType.print(llvm::outs()); + llvm::outs() << "\n"; + Type convertedType = typeConverter.convertType(originalType); + if (convertedType == originalType) { + llvm::outs() << "Type remains the same, no conversion needed.\n"; + return value; + } + if (auto memrefTy = dyn_cast(originalType)) { + llvm::outs() << "Converting memref type to: "; + convertedType.print(llvm::outs()); + llvm::outs() << "\n"; + return builder.create(loc, convertedType, value); + } + llvm::outs() << "No suitable conversion for type, returning original value.\n"; + return value; +} + +struct ScfForOpTypeConversionPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::outs() << "Processing scf.for operation: "; + op->print(llvm::outs()); + llvm::outs() << "\n"; + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) { + llvm::outs() << "Failed to convert result types for scf.for operation.\n"; + return failure(); + } + llvm::outs() << "Converted result types for scf.for operation: "; + for (Type type : newResultTypes) { + type.print(llvm::outs()); + llvm::outs() << " "; + } + llvm::outs() << "\n"; + + SmallVector newInitArgs; + for (Value initArg : adaptor.getInitArgs()) { + newInitArgs.push_back(convertValueIfNeeded(rewriter, op.getLoc(), initArg, *getTypeConverter())); + } + + auto newOp = rewriter.create( + op.getLoc(), + adaptor.getLowerBound(), + adaptor.getUpperBound(), + adaptor.getStep(), + newInitArgs, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange args) { + llvm::outs() << "Inside scf.for loop body, induction variable: "; + iv.print(llvm::outs()); + llvm::outs() << "\n"; + Block &oldBlock = op.getRegion().front(); + llvm::DenseMap valueMap; + valueMap[oldBlock.getArgument(0)] = iv; + auto newArgsIter = args.begin(); + for (auto oldArg : oldBlock.getArguments().drop_front()) { + valueMap[oldArg] = *newArgsIter; + ++newArgsIter; + } + + for (auto &oldOp : oldBlock.without_terminator()) { + SmallVector newOperands; + for (Value operand : oldOp.getOperands()) { + if (valueMap.count(operand)) { + newOperands.push_back(valueMap[operand]); + } else { + newOperands.push_back(operand); + } + } + Operation *newOp = builder.clone(oldOp); + newOp->setOperands(newOperands); + for (auto resultPair : llvm::zip(oldOp.getResults(), newOp->getResults())) { + valueMap[std::get<0>(resultPair)] = std::get<1>(resultPair); + } + } + + auto yieldOp = cast(oldBlock.getTerminator()); + SmallVector newYieldOperands; + for (auto operand : yieldOp.getOperands()) { + if (valueMap.count(operand)) { + newYieldOperands.push_back(valueMap[operand]); + } else { + newYieldOperands.push_back(operand); + } + } + builder.create(loc, newYieldOperands); + } + ); + + rewriter.replaceOp(op, newOp.getResults()); + llvm::outs() << "Replaced scf.for operation with new one.\n"; + return success(); + } +}; + +struct CallOpTypeConversionPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::outs() << "Processing CallOp: "; + op->print(llvm::outs()); + llvm::outs() << "\n"; + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) { + llvm::outs() << "Failed to convert result types for CallOp.\n"; + return failure(); + } + llvm::outs() << "Converted result types for CallOp: "; + for (Type type : newResultTypes) { + type.print(llvm::outs()); + llvm::outs() << " "; + } + llvm::outs() << "\n"; + + SmallVector newOperands; + for (Value operand : adaptor.getOperands()) { + newOperands.push_back(convertValueIfNeeded(rewriter, op.getLoc(), operand, *getTypeConverter())); + } + + auto newOp = rewriter.create(op.getLoc(), op.getCallee(), newResultTypes, newOperands); + rewriter.replaceOp(op, newOp.getResults()); + llvm::outs() << "Replaced CallOp with new one.\n"; + return success(); + } +}; + +struct ReturnOpTypeConversionPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::outs() << "Processing ReturnOp: "; + op->print(llvm::outs()); + llvm::outs() << "\n"; + SmallVector newOperands; + for (Value operand : adaptor.getOperands()) { + newOperands.push_back(convertValueIfNeeded(rewriter, op.getLoc(), operand, *getTypeConverter())); + } + rewriter.replaceOpWithNewOp(op, newOperands); + llvm::outs() << "Replaced ReturnOp with new one.\n"; + return success(); + } +}; + +inline void populateTypeConversionPatterns1(RewritePatternSet &patterns, TypeConverter &typeConverter, ConversionTarget &target) { + populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](FuncOp func) { + bool isLegal = typeConverter.isSignatureLegal(func.getFunctionType()); + llvm::outs() << "FuncOp " << func.getName() << " signature is " << (isLegal ? "legal" : "illegal") << "\n"; + return isLegal; + }); + + patterns.add(typeConverter, patterns.getContext()); + target.addDynamicallyLegalOp( + [&](CallOp op) { + bool isLegal = true; + for (Type type : op.getOperandTypes()) { + if (auto memrefTy = dyn_cast(type)) { + if (memrefTy.hasStaticShape()) continue; + } + if (!typeConverter.isLegal(type)) { + isLegal = false; + break; + } + } + for (Type type : op.getResultTypes()) { + if (auto memrefTy = dyn_cast(type)) { + if (memrefTy.hasStaticShape()) continue; + } + if (!typeConverter.isLegal(type)) { + isLegal = false; + break; + } + } + llvm::outs() << "CallOp " << op.getCallee() << " is " << (isLegal ? "legal" : "illegal") << "\n"; + return isLegal; + }); + + patterns.add(typeConverter, patterns.getContext()); + target.addDynamicallyLegalOp( + [&](ReturnOp op) { + bool isLegal = true; + for (Type type : op.getOperandTypes()) { + if (auto memrefTy = dyn_cast(type)) { + if (memrefTy.hasStaticShape()) continue; + } + if (!typeConverter.isLegal(type)) { + isLegal = false; + break; + } + } + llvm::outs() << "ReturnOp is " << (isLegal ? "legal" : "illegal") << "\n"; + return isLegal; + }); + + patterns.add(typeConverter, patterns.getContext()); + target.addDynamicallyLegalOp( + [&](scf::ForOp op) { + bool isLegal = isLegalForUnrankedTypeConversion(op, typeConverter); + llvm::outs() << "scf.for operation is " << (isLegal ? "legal" : "illegal") << "\n"; + return isLegal; + }); +} + +struct UnrealizedConversionCastPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getInputs().size() != 1 || op.getOutputs().size() != 1) { + return failure(); + } + + Value input = op.getInputs()[0]; + Type resultType = op.getOutputs()[0].getType(); + + if (isa(input.getType()) && isa(resultType)) { + rewriter.replaceOpWithNewOp(op, resultType, input); + return success(); + } + + rewriter.replaceOpWithNewOp(op, resultType, input); + return success(); + } +}; + +struct ConvertRankedToUnrankedPass : public PassWrapper> { + + void runOnOperation() override { + llvm::outs() << "Starting ConvertRankedToUnrankedPass on module: "; + getOperation()->print(llvm::outs()); + llvm::outs() << "\n"; + auto module = getOperation(); + MLIRContext *ctx = &getContext(); + + TypeConverter typeConverter; + + typeConverter.addConversion([&](Type t) -> Type { + llvm::outs() << "Checking type for conversion: "; + t.print(llvm::outs()); + llvm::outs() << "\n"; + + // 使用 isa 进行类型检查 + if (isa(t)) { + llvm::outs() << "Type is recognized as MemRefType before dyn_cast.\n"; + } else { + llvm::outs() << "Type is not recognized as MemRefType before dyn_cast.\n"; + } + + if (auto memrefTy = dyn_cast(t)) { + llvm::outs() << "MemRef type details: "; + memrefTy.print(llvm::outs()); + llvm::outs() << "\n"; + + bool hasDynamicDim = false; + for (int64_t dim : memrefTy.getShape()) { + if (ShapedType::isDynamic(dim)) { + hasDynamicDim = true; + break; + } + } + + if (hasDynamicDim) { + llvm::outs() << "Converting MemRef with dynamic shape to single-dimensional dynamic shape\n"; + return MemRefType::get( + {ShapedType::kDynamic}, + memrefTy.getElementType(), + AffineMap(), + memrefTy.getMemorySpace() + ); + } + + llvm::outs() << "No dynamic dimension found, no conversion needed.\n"; + } else { + llvm::outs() << "dyn_cast to MemRefType failed for type: "; + t.print(llvm::outs()); + llvm::outs() << "\n"; + } + + llvm::outs() << "No conversion needed for type.\n"; + return t; + }); + + // 修复材料化函数的返回类型 + typeConverter.addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + llvm::outs() << "Source materialization called. Inputs count: " << inputs.size() << "\n"; + if (inputs.size() != 1) { + llvm::outs() << "Expected 1 input, got " << inputs.size() << ", returning nullptr.\n"; + return nullptr; + } + + Value input = inputs[0]; + if (isa(input.getType()) && isa(resultType)) { + llvm::outs() << "Creating memref::CastOp for source materialization\n"; + return builder.create(loc, resultType, input); + } + + llvm::outs() << "No suitable cast operation for source materialization\n"; + return nullptr; + }); + + typeConverter.addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + llvm::outs() << "Target materialization called. Inputs count: " << inputs.size() << "\n"; + if (inputs.size() != 1) { + llvm::outs() << "Expected 1 input, got " << inputs.size() << ", returning nullptr.\n"; + return nullptr; + } + + Value input = inputs[0]; + if (isa(input.getType()) && isa(resultType)) { + llvm::outs() << "Creating memref::CastOp for target materialization\n"; + return builder.create(loc, resultType, input); + } + + llvm::outs() << "No suitable cast operation for target materialization\n"; + return nullptr; + }); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + bool isLegal = isLegalForUnrankedTypeConversion(op, typeConverter); + llvm::outs() << "Operation "; + op->getName().print(llvm::outs()); + llvm::outs() << " is " << (isLegal ? "legal" : "illegal") << "\n"; + return isLegal; + }); + + RewritePatternSet patterns(ctx); + populateTypeConversionPatterns1(patterns, typeConverter, target); + patterns.add(typeConverter, ctx); + + // 添加标准函数转换模式 + // populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + // populateCallOpTypeConversionPattern(patterns, typeConverter); + // populateReturnOpTypeConversionPattern(patterns, typeConverter); + // populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + // populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + + if (failed(applyFullConversion(module, target, std::move(patterns)))) { + llvm::outs() << "Failed to apply full conversion on module.\n"; + signalPassFailure(); + } else { + llvm::outs() << "Successfully applied full conversion on module.\n"; + } + + // 后处理:清理不必要的cast操作 + module.walk([&](memref::CastOp castOp) { + if (castOp.getOperand().getType() == castOp.getType()) { + castOp.replaceAllUsesWith(castOp.getOperand()); + castOp.erase(); + llvm::outs() << "Removed redundant cast operation\n"; + } + }); + + llvm::outs() << "After ConvertRankedToUnrankedPass on module: "; + getOperation()->print(llvm::outs()); + llvm::outs() << "\nFinished ConvertRankedToUnrankedPass on module.\n"; + } + + StringRef getArgument() const final { return "convert-ranked-to-unranked"; } + StringRef getDescription() const final { + return "Convert memref types with dynamic shape to single-dimensional dynamic ones"; + } +}; + +inline std::unique_ptr createConvertRankedToUnrankedPass() { + return std::make_unique(); +} + +} // namespace npu +} // namespace dicp +} // namespace mlir + +#endif // CONVERT_RANKED_TO_UNRANKED_PASS_H \ No newline at end of file diff --git a/compiler/include/dicp/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.hpp b/compiler/include/dicp/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.hpp new file mode 100644 index 00000000..8a9e092d --- /dev/null +++ b/compiler/include/dicp/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include "mlir/Pass/Pass.h" +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include +#include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/SmallVector.h" +#include + + +using namespace mlir; +using namespace mlir::func; + +namespace mlir { +namespace dicp { +namespace npu { + +struct VerifyNoLinalgGenericPass : public PassWrapper> { + StringRef getArgument() const final { return "verify-no-linalg-generic"; } + StringRef getDescription() const final { + return "Verify that no 'linalg.generic' operations exist"; + } + + void runOnOperation() override { + bool foundGeneric = false; + getOperation()->walk([&](linalg::GenericOp op) { + op.emitError() << "linalg.generic is not allowed in this pass pipeline."; + foundGeneric = true; + }); + + if (foundGeneric) { + signalPassFailure(); + } + } +}; + +inline std::unique_ptr createVerifyNoLinalgGenericPass() { + return std::make_unique(); +} + +} // namespace npu +} // namespace dicp +} // namespace mlir + diff --git a/compiler/lib/Conversion/LinalgToNPU/CMakeLists.txt b/compiler/lib/Conversion/LinalgToNPU/CMakeLists.txt index e48b1cba..cd36bb50 100644 --- a/compiler/lib/Conversion/LinalgToNPU/CMakeLists.txt +++ b/compiler/lib/Conversion/LinalgToNPU/CMakeLists.txt @@ -1,6 +1,8 @@ add_triton_library(LinalgToNPU LinalgToNPUPass.cpp LinalgToNPU.cpp + ConvertRankedToUnrankedPass.cpp + VerifyNoLinalgGenericPass.cpp DEPENDS DICPNPUIncGen diff --git a/compiler/lib/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.cpp b/compiler/lib/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.cpp new file mode 100644 index 00000000..89d634bc --- /dev/null +++ b/compiler/lib/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.cpp @@ -0,0 +1,40 @@ +#include "dicp/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.hpp" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +// #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "mlir/Pass/PassManager.h" +#include "dicp/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.hpp" + + +using namespace mlir; +using namespace mlir::func; + +namespace mlir::dicp::npu { + +// // 定义 ConvertRankedToUnrankedPass +// struct ConvertRankedToUnrankedPass +// : public PassWrapper> { + + + +// }; + + +// std::unique_ptr createConvertRankedToUnrankedPass() { +// return std::make_unique(); +// } + + +} // namespace mlir::dicp::npu diff --git a/compiler/lib/Conversion/LinalgToNPU/LinalgToNPU.cpp b/compiler/lib/Conversion/LinalgToNPU/LinalgToNPU.cpp index 0a34b966..356f46b1 100644 --- a/compiler/lib/Conversion/LinalgToNPU/LinalgToNPU.cpp +++ b/compiler/lib/Conversion/LinalgToNPU/LinalgToNPU.cpp @@ -31,4 +31,5 @@ void npu::populateLinalgToNPUConversionPatterns(RewritePatternSet &patterns) { // patterns.add(patterns.getContext()); // patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + // patterns.add(patterns.getContext()); } diff --git a/compiler/lib/Conversion/LinalgToNPU/LinalgToNPUPass.cpp b/compiler/lib/Conversion/LinalgToNPU/LinalgToNPUPass.cpp index 760ec2c1..9773d830 100644 --- a/compiler/lib/Conversion/LinalgToNPU/LinalgToNPUPass.cpp +++ b/compiler/lib/Conversion/LinalgToNPU/LinalgToNPUPass.cpp @@ -1,5 +1,7 @@ #include "dicp/Conversion/LinalgToNPU/LinalgToNPU.h" #include "dicp/Conversion/LinalgToNPU/AddWorkspaceAndAttrsPass.hpp" +#include "dicp/Conversion/LinalgToNPU/ConvertRankedToUnrankedPass.hpp" +#include "dicp/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.hpp" #include "dicp/Dialect/NPU/IR/NPUDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -241,6 +243,18 @@ class LinalgToNPUPass : public LinalgToNPUBase { [&](func::FuncOp func) { convertTTFunc(func, existDot); }); std::cout << "Function header and footer conversion completed successfully.\n"; + + PassManager pm(context); + pm.addPass(mlir::dicp::npu::createVerifyNoLinalgGenericPass()); + if (failed(pm.run(moduleOp))) { + signalPassFailure(); + } + // 使用 PassManager 执行 ConvertRankedToUnrankedPass,但是华为不支持*x的形式,只能在python层字符串替换。尽力了 + // pm.addPass(mlir::dicp::npu::createConvertRankedToUnrankedPass()); + // if (failed(pm.run(moduleOp))) { + // signalPassFailure(); + // } + std::cout << "Adding workspace argument to functions...\n"; // 新增功能逻辑:强制在函数参数开头添加一个参数,代表工作空间的占位参数 for (auto func : getOperation().getOps()) { @@ -263,6 +277,8 @@ class LinalgToNPUPass : public LinalgToNPUBase { } } + + }; } // namespace diff --git a/compiler/lib/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.cpp b/compiler/lib/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.cpp new file mode 100644 index 00000000..a30291a8 --- /dev/null +++ b/compiler/lib/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.cpp @@ -0,0 +1,7 @@ +#include "dicp/Conversion/LinalgToNPU/VerifyNoLinalgGenericPass.hpp" + +using namespace mlir; +using namespace mlir::func; + +namespace mlir::dicp::npu { +} // namespace mlir::dicp::npu