Bump llvm-project to 6b65d79 (#2723)
Co-authored-by: hanhanW <[email protected]>
Groverkss and hanhanW authored Jan 4, 2024
1 parent aa7e95f commit fb1dfa3
Showing 4 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 1596 files
21 changes: 9 additions & 12 deletions lib/Dialect/TMTensor/IR/TMTensorOps.cpp
@@ -166,7 +166,6 @@ static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
})
->getResult(0);
b.create<memref::StoreOp>(loc, sum, output, localIVs);
-        b.create<scf::YieldOp>(loc);
});
}
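
Note on the dropped `scf::YieldOp` creations in this file: they track an upstream SCF change in the bumped LLVM revision, where `scf.reduce` became the only terminator of `scf.parallel` and the builders insert it implicitly, so explicitly created `scf.yield` ops would no longer verify. A hedged sketch of a post-bump parallel body builder, assuming these loops are built with `scf::ParallelOp` (bound and buffer names are illustrative, not taken from this commit):

```cpp
// Sketch only: the body builder just writes the loop body; no scf.yield is
// created, the terminator is supplied by the ParallelOp builder itself.
b.create<scf::ParallelOp>(
    loc, lowerBounds, upperBounds, steps,
    [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
      Value v = nested.create<memref::LoadOp>(nestedLoc, buffer, ivs);
      nested.create<memref::StoreOp>(nestedLoc, v, buffer, ivs);
      // Previously: nested.create<scf::YieldOp>(nestedLoc);
    });
```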

@@ -229,13 +228,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
SmallVector<Value>(weightRank, one), init,
[&](OpBuilder &b, Location loc, ValueRange localIVs,
ValueRange accs) {
-            b.create<scf::ReduceOp>(
-                loc, init,
-                [&](OpBuilder &b, Location loc, Value elem, Value acc) {
-                  Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
-                  Value max = b.create<arith::MaximumFOp>(loc, x, acc);
-                  b.create<scf::ReduceReturnOp>(loc, max);
-                });
+            auto reduceOp = b.create<scf::ReduceOp>(loc, init);
+            // Build reduce body.
+            Block &reductionBody = reduceOp.getReductions()[0].front();
+            auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody);
+            Value acc = reductionBody.getArgument(0);
+            Value x =
+                bodyBuilder.create<memref::LoadOp>(loc, weight, localIVs);
+            Value max = bodyBuilder.create<arith::MaximumFOp>(loc, x, acc);
+            bodyBuilder.create<scf::ReduceReturnOp>(loc, max);
})
.getResult(0);
// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
@@ -247,7 +248,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
x = b.create<arith::SubFOp>(loc, x, globalMax);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
-        b.create<scf::YieldOp>(loc);
});
// calculate exp(weight)
SmallVector<Value> min(weightRank, zero),
@@ -258,7 +258,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<math::ExpOp>(loc, x);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
-        b.create<scf::YieldOp>(loc);
});
Value expWeightSum = b.create<memref::AllocOp>(
loc,
@@ -290,7 +289,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value y = b.create<memref::LoadOp>(loc, weight, coords);
Value sum = b.create<arith::AddFOp>(loc, x, y);
b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
-          b.create<scf::YieldOp>(loc);
});
});
// calculate exp(weight) / sum(exp(weight))
@@ -305,7 +303,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
x = b.create<arith::DivFOp>(loc, x, sum);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
-        b.create<scf::YieldOp>(loc);
});

// output = weight @ value
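The larger rewrite above reflects the same upstream change on the `scf::ReduceOp` side: the convenience builder that took a reduction-body callback is gone, so the op is created first and its reduction region is populated by hand, exactly as the new lines do. A hedged sketch of that pattern for a floating-point max reduction (the helper name is illustrative, not from this file; it assumes `b` is positioned at the end of an `scf.parallel` body and `elem` is the value fed into the reduction):

```cpp
// Sketch mirroring the new code path above: create scf.reduce with its
// operand, then fill the first reduction region manually.
static void buildMaxReduce(OpBuilder &b, Location loc, Value elem) {
  auto reduceOp = b.create<scf::ReduceOp>(loc, elem);
  // The builder creates an empty reduction block with two arguments (the
  // accumulated value and the incoming value), both of the operand's type.
  Block &body = reduceOp.getReductions()[0].front();
  OpBuilder bodyBuilder = OpBuilder::atBlockEnd(&body);
  Value lhs = body.getArgument(0);
  Value rhs = body.getArgument(1);
  Value max = bodyBuilder.create<arith::MaximumFOp>(loc, lhs, rhs);
  bodyBuilder.create<scf::ReduceReturnOp>(loc, max);
  // scf.reduce now terminates the enclosing scf.parallel body; no scf.yield
  // is created after it.
}
```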
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -715,6 +715,8 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
+  if (getOperand().getType() != getResult().getType())
+    return nullptr;
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand();
@@ -727,6 +729,8 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
+  if (getOperand(0).getType() != getResult().getType())
+    return nullptr;
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
@@ -739,6 +743,8 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
+  if (getSelf().getType() != getResult().getType())
+    return nullptr;
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
return getSelf();
@@ -911,6 +917,8 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
auto resType = getType().dyn_cast<BaseTensorType>();
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
return nullptr;
+  if (inputType != resType)
+    return nullptr;
// Fold when both the input tensor and result are unity rank tensors.
return getOperand(0);
}
@@ -2441,6 +2449,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
+  if (list.getElements()[0].getType() != getResult().getType())
+    return nullptr;
return list.getElements()[0];
}

@@ -2451,6 +2461,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
+  if (inType != outType)
+    return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
@@ -2480,6 +2492,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {

auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
+  if (inType != outType)
+    return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
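All fourteen added lines in this file follow one pattern: an identity fold (returning an operand, or a single list element, unchanged) is only legal when that value's type already equals the op's result type, because `fold` must not change the type of the value it produces. A hedged sketch of the guard, using a hypothetical op name standing in for the ops patched above (squeeze, round, view, cat, broadcast_to, slice):

```cpp
// Illustrative only; AtenIdentityLikeOp is a made-up name, not a real op.
OpFoldResult AtenIdentityLikeOp::fold(FoldAdaptor adaptor) {
  // Folding to the operand would silently drop or add a type refinement
  // (e.g. static sizes or dtype) if operand and result types differ, so bail.
  if (getOperand().getType() != getResult().getType())
    return nullptr;
  return getOperand();
}
```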
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
@@ -95,7 +95,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
-    setSafe();
+    (void)setSafe();
}

void print(raw_ostream &os) const override {
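The `(void)` cast is the usual way to state that a `[[nodiscard]]` return value (here the dataflow `ChangeResult` returned by `setSafe()`, which the bumped LLVM appears to flag as unused-result) is being ignored on purpose. A standalone sketch of the idiom, with an illustrative function name:

```cpp
// Minimal illustration, unrelated to MLIR types: casting the call to void
// documents the intentional discard and silences -Wunused-result.
[[nodiscard]] static bool markSafe() { return true; }

int main() {
  (void)markSafe(); // result intentionally ignored
  return 0;
}
```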
