2 changes: 1 addition & 1 deletion backend/npu.py
@@ -346,7 +346,7 @@ def ttsharedir_to_dicp(mod, metadata, opt, *, named_ops=False):
# content = content.replace('func.func @_silu_and_mul_kernel(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32)',
# 'func.func @_silu_and_mul_kernel(%arg1000: memref<?xi8>, %arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) attributes {WorkspaceArgIdx = 0 : i64, global_kernel = "local", mix_mode = "aiv"}')
# Replace "*xf16" with "?xf16"
content = content.replace("*xf16", "?xf16")
content = content.replace("*xf", "?xf")
print(f"zmz debug: after replace content: {content}")
# Write the content back to the file
with open(dst_path, 'w') as f:
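A quick editor's sketch (illustration only, not part of the patch) of why the search string was broadened: matching on "*xf" instead of "*xf16" also rewrites unranked memrefs of other float widths, such as f32, in the generated MLIR text.

content = "memref<*xf16> {tt.divisibility = 16 : i32}, memref<*xf32>"
print(content.replace("*xf16", "?xf16"))  # old pattern: the *xf32 operand is left unchanged
print(content.replace("*xf", "?xf"))      # new pattern: both become ?xf16 and ?xf32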
188 changes: 187 additions & 1 deletion compiler/include/dicp/Conversion/LinalgToNPU/ConversionPatterns.hpp
@@ -11,6 +11,15 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/IR/BuiltinAttributes.h"
// #include "mlir/Dialect/Arith/IR/ArithAttributes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
// #include "mlir/Dialect/Linalg/IR/LinalgAttributes.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include <iostream>
#include <numeric>
@@ -19,6 +28,7 @@

using namespace mlir;
using namespace dicp;
using namespace mlir::utils;

namespace {

@@ -331,4 +341,180 @@ struct ConvertLinalgGenericToArith
};


} // namespace
// ... existing code ...


struct ConvertLinalgGenericToBroadcast : public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
std::cout << "[ConvertLinalgGenericToBroadcast] Starting matchAndRewrite for LinalgGenericOp at location: ";
genericOp.getLoc().print(llvm::outs());
std::cout << "\n";

// === Step 1: Check whether this op is a valid broadcast pattern ===
if (!isBroadcastPattern(genericOp)) {
std::cout << "[INFO] Not a broadcast pattern\n";
return failure();
}
std::cout << "[INFO] Detected broadcast pattern\n";

// === Step 2: Get the input and output tensors ===
Value input = genericOp.getDpsInputOperand(0)->get();
Value output = genericOp.getDpsInitOperand(0)->get();

// Extract the broadcastDims attribute
auto attr = genericOp->getAttr("broadcastDims");
if (!attr) {
std::cout << "[ERROR] Missing 'broadcastDims' attribute\n";
return failure();
}

std::cout << "[DEBUG] Raw broadcastDims attribute value: ";
attr.print(llvm::outs());
std::cout << "\n";
SmallVector<int64_t> broadcastDims;
if (auto denseAttr = dyn_cast<DenseI64ArrayAttr>(attr)) {
// broadcastDims = denseAttr.asArrayRef().vec();
broadcastDims.assign(denseAttr.asArrayRef().begin(), denseAttr.asArrayRef().end());
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
if (auto intAttr = dyn_cast<IntegerAttr>(element)) {
broadcastDims.push_back(intAttr.getInt());
} else {
std::cout << "[ERROR] Invalid element in 'broadcastDims' array\n";
return failure();
}
}
} else {
std::cout << "[ERROR] Invalid 'broadcastDims' attribute type\n";
return failure();
}

std::cout << "[INFO] Detected broadcastDims = [";
for (int64_t d : broadcastDims) {
std::cout << d << " ";
}
std::cout << "]\n";

// Create the linalg.broadcast operation
rewriter.setInsertionPoint(genericOp);
auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
genericOp.getLoc(),
input,
output,
broadcastDims
);
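// Editor's note (assumption about conventions, not asserted by this change):
// linalg.broadcast interprets `dimensions` as the positions of the init/result
// that are added relative to the input, e.g. broadcasting tensor<16xf32> into
// tensor<4x16xf32> uses dimensions = [0]; the broadcastDims attribute is
// assumed here to already be encoded that way.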

// Replace the original op
// Print the op being replaced
std::cout << "[DEBUG] Replacing linalg.generic with linalg.broadcast\n";
std::cout << "[DEBUG] Before replacement:\n";
genericOp.print(llvm::outs());
std::cout << "\n";
// Print the replacement op
std::cout << "[DEBUG] After replacement:\n";
broadcastOp.print(llvm::outs());
std::cout << "\n";
rewriter.replaceOp(genericOp, broadcastOp.getResult());
std::cout << "[INFO] Replaced linalg.generic with linalg.broadcast\n";
return success();
}

private:
// Check whether the generic op is a broadcast pattern
bool isBroadcastPattern(linalg::GenericOp op) const {
std::cout << "[DEBUG] Checking broadcast pattern for linalg.generic at location: ";
op.getLoc().print(llvm::outs());
std::cout << "\n";

// 1. Check the iterator types
SmallVector<StringRef> iterTypes;
if (failed(getIteratorTypeNames(op, iterTypes))) {
std::cout << "[DEBUG] Failed to get iterator type names. Not a broadcast pattern.\n";
return false;
}
if (!llvm::all_of(iterTypes, [](StringRef type) { return type == "parallel"; })) {
std::cout << "[DEBUG] Iterator types are not all 'parallel'.\n";
return false;
}
std::cout << "[DEBUG] All iterator types are 'parallel'.\n";

// 2. Check the number of inputs and outputs
if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1) {
std::cout << "[DEBUG] Expected 1 input and 1 output.\n";
return false;
}
std::cout << "[DEBUG] Number of inputs and outputs matched.\n";

// 3. Check the block structure
if (op->getNumRegions() == 0) {
std::cout << "[ERROR] Operation has no regions.\n";
return false;
}

Region *region = &op->getRegion(0);
if (!region || !region->hasOneBlock()) {
std::cout << "[ERROR] Region not valid or does not have one block.\n";
return false;
}
std::cout << "[DEBUG] Region has exactly one block.\n";

Block &block = region->front();
if (block.empty() || !isa<linalg::YieldOp>(block.back())) {
std::cout << "[DEBUG] Block does not end with linalg.yield.\n";
return false;
}
std::cout << "[DEBUG] Block is not empty and ends with linalg.yield.\n";

Operation *innerOp = block.getTerminator()->getPrevNode();
if (innerOp && !isa<linalg::YieldOp>(innerOp)) {
std::cout << "[DEBUG] Expected only linalg.yield in block.\n";
return false;
}
std::cout << "[DEBUG] No inner operation before yield.\n";

auto yieldOp = cast<linalg::YieldOp>(block.getTerminator());
if (yieldOp->getNumOperands() != 1 || yieldOp->getOperand(0) != block.getArgument(0)) {
std::cout << "[DEBUG] Yield operand mismatch.\n";
return false;
}
std::cout << "[DEBUG] Yield operand check passed.\n";

std::cout << "[DEBUG] Inner operation and yield check passed.\n";

// 4. Check for the broadcastDims attribute
if (!op->hasAttr("broadcastDims")) {
std::cout << "[DEBUG] Missing 'broadcastDims' attribute.\n";
return false;
}

std::cout << "[DEBUG] 'broadcastDims' attribute found. This is a broadcast pattern.\n";
return true;
}

// Get the iterator type names
LogicalResult getIteratorTypeNames(linalg::GenericOp op,
SmallVectorImpl<StringRef> &types) const {
auto iteratorAttrs = op.getIteratorTypes().getValue();
for (Attribute attr : iteratorAttrs) {
if (auto iterTypeAttr = dyn_cast<linalg::IteratorTypeAttr>(attr)) {
types.push_back(mlir::utils::stringifyIteratorType(iterTypeAttr.getValue()));
} else if (auto strAttr = dyn_cast<StringAttr>(attr)) {
types.push_back(strAttr.getValue());
} else {
std::cout << "[ERROR] Unsupported iterator type attribute.\n";
return failure();
}
}
return success();
}
};

// ... existing code ...



} // namespace
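For context, a hedged sketch of how the new pattern would typically be registered next to the existing linalg.generic conversions; the function name below is an assumption for illustration and does not appear in this diff.

// Hypothetical registration hook (name assumed, not shown in this PR).
void populateLinalgToNPUConversionPatterns(RewritePatternSet &patterns) {
  // Register the existing arith conversion together with the new broadcast pattern.
  patterns.add<ConvertLinalgGenericToArith,
               ConvertLinalgGenericToBroadcast>(patterns.getContext());
}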