Skip to content

Commit

Permalink
Merge pull request #838 from WoutLegiest:dev-boolvec
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 659578929
  • Loading branch information
copybara-github committed Aug 5, 2024
2 parents 522b876 + 9d893a6 commit ac2ef64
Show file tree
Hide file tree
Showing 14 changed files with 466 additions and 12 deletions.
1 change: 1 addition & 0 deletions lib/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/CGGI/IR/CGGIAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@

include "lib/Dialect/CGGI/IR/CGGIDialect.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/CommonAttrConstraints.td"

class CGGI_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<CGGI_Dialect, name, traits> {
let mnemonic = attrMnemonic;
let assemblyFormat = "`<` struct(params) `>`";
}

def CGGI_CGGIParams : AttrDef<CGGI_Dialect, "CGGIParams"> {
let mnemonic = "cggi_params";
Expand All @@ -22,4 +30,26 @@ def CGGI_CGGIParams : AttrDef<CGGI_Dialect, "CGGIParams"> {
let assemblyFormat = "`<` struct(params) `>`";
}

def CGGIGate_Attr : CGGI_Attr<"CGGIGate", "cggi_gate"> {
let summary = "An Attribute containing an array of strings to store bool gates";

let description = [{
This attributes stores a list of string identifiers for Boolean gates.

This used in the `cggi.packed` operation to indicate the boolean gate that applies pairwise between elements of two ciphertext arrays. For example,

%0 = cggi.packed %a, %b {gates = #cggi.cggi_gate<"and", "xor">} : tensor<2x!lwe.lwe_ciphertext>

applies an "and" gate between the first elements of %a and %b and an xor gate between the second elements.
}];

let parameters = (ins
ArrayRefParameter<"mlir::StringAttr">: $gate);

let assemblyFormat = "`<` $gate `>`";
}




#endif // LIB_DIALECT_CGGI_IR_CGGIATTRIBUTES_TD_
3 changes: 2 additions & 1 deletion lib/Dialect/CGGI/IR/CGGIDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/CommonAttrConstraints.td"


def CGGI_Dialect : Dialect {
let name = "cggi";
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/CGGI/IR/CGGIOps.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#ifndef LIB_DIALECT_CGGI_IR_CGGIOPS_H_
#define LIB_DIALECT_CGGI_IR_CGGIOPS_H_

#include "lib/Dialect/CGGI/IR/CGGIAttributes.h"
#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/HEIRInterfaces.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

#define GET_OP_CLASSES
Expand Down
15 changes: 15 additions & 0 deletions lib/Dialect/CGGI/IR/CGGIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

include "lib/Dialect/HEIRInterfaces.td"
include "lib/Dialect/CGGI/IR/CGGIDialect.td"
include "lib/Dialect/CGGI/IR/CGGIAttributes.td"

include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.td"
include "lib/Dialect/LWE/IR/LWETypes.td"

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

Expand Down Expand Up @@ -57,6 +59,19 @@ def CGGI_NotOp : CGGI_Op<"not", [
let summary = "Logical NOT of two ciphertexts";
}

def CGGI_PackedOp : CGGI_Op<"packed_gates", [
Pure,
SameOperandsAndResultType
]> {
let arguments = (ins
CGGIGate_Attr:$gates,
LWECiphertextLike:$lhs,
LWECiphertextLike:$rhs
);

let results = (outs LWECiphertextLike:$output);
}

class CGGI_LutOp<string mnemonic, list<Trait> traits = []>
: CGGI_Op<mnemonic, traits # [
Pure,
Expand Down
24 changes: 24 additions & 0 deletions lib/Dialect/CGGI/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_library(
"Passes.h",
],
deps = [
":BooleanLineVectorizer",
":SetDefaultParameters",
":pass_inc_gen",
"@heir//lib/Dialect/CGGI/IR:Dialect",
Expand All @@ -34,6 +35,29 @@ cc_library(
],
)

cc_library(
name = "BooleanLineVectorizer",
srcs = ["BooleanLineVectorizer.cpp"],
hdrs = [
"BooleanLineVectorizer.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/CGGI/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Graph",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
Expand Down
225 changes: 225 additions & 0 deletions lib/Dialect/CGGI/Transforms/BooleanLineVectorizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
#include "lib/Dialect/CGGI/Transforms/BooleanLineVectorizer.h"

#include <string>

#include "lib/Dialect/CGGI/IR/CGGIAttributes.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Graph/Graph.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/TopologicalSortUtils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

#define DEBUG_TYPE "bool-line-vectorizer"

namespace mlir {
namespace heir {
namespace cggi {

#define GEN_PASS_DEF_BOOLEANLINEVECTORIZER
#include "lib/Dialect/CGGI/Transforms/Passes.h.inc"

bool areCompatibleBool(Operation *lhs, Operation *rhs) {
if (lhs->getDialect() != rhs->getDialect() ||
lhs->getResultTypes() != rhs->getResultTypes() ||
lhs->getAttrs() != rhs->getAttrs()) {
return false;
}
// TODO: Check if can be made better with a BooleanPackableGate trait
// on the CGGI_BinaryGateOp's?
return OpTrait::hasElementwiseMappableTraits(lhs);
}

bool tryBoolVectorizeBlock(Block *block, MLIRContext &context) {
graph::Graph<Operation *> graph;
for (auto &op : block->getOperations()) {
if (!op.hasTrait<OpTrait::Elementwise>()) {
continue;
}

graph.addVertex(&op);
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
options.omitBlockArguments = true;

getBackwardSlice(&op, &backwardSlice, options);
for (auto *upstreamDep : backwardSlice) {
// An edge from upstreamDep to `op` means that upstreamDep must be
// computed before `op`.
graph.addEdge(upstreamDep, &op);
}
}

if (graph.empty()) {
return false;
}

auto result = graph.sortGraphByLevels();
assert(succeeded(result) &&
"Only possible failure is a cycle in the SSA graph!");
auto levels = result.value();

LLVM_DEBUG({
llvm::dbgs()
<< "Found operations to vectorize. In topo-sorted level order:\n";
int level_num = 0;
for (const auto &level : levels) {
llvm::dbgs() << "\nLevel " << level_num++ << ":\n";
for (auto op : level) {
llvm::dbgs() << " - " << *op << "\n";
}
}
});

bool madeReplacement = false;
for (const auto &level : levels) {
DenseMap<Operation *, SmallVector<Operation *, 4>> compatibleOps;
for (auto *op : level) {
bool foundCompatible = false;
for (auto &[key, bucket] : compatibleOps) {
if (areCompatibleBool(key, op)) {
compatibleOps[key].push_back(op);
foundCompatible = true;
}
}
if (!foundCompatible) {
compatibleOps[op].push_back(op);
}
}
LLVM_DEBUG(llvm::dbgs()
<< "Partitioned level of size " << level.size() << " into "
<< compatibleOps.size() << " groups of compatible ops\n");

// Loop over all the compatibleOp groups
// Each loop will have the key and a bucket with all the operations in
for (auto &[key, bucket] : compatibleOps) {
if (bucket.size() < 2) {
continue;
}

LLVM_DEBUG({
llvm::dbgs() << "[START] Bucket \t Vectorizing ops:\n";
for (auto op : bucket) {
llvm::dbgs() << " - " << *op << "\n";
}
});

OpBuilder builder(bucket.back());
// relies on CGGI ops having a single result type
Type elementType = key->getResultTypes()[0];
RankedTensorType tensorType = RankedTensorType::get(
{static_cast<int64_t>(bucket.size())}, elementType);

SmallVector<Value, 4> vectorizedOperands;
SmallVector<StringAttr, 4> vectorizedGateOperands;

for (auto *op : bucket) {
std::string str;
if (isa<cggi::AndOp>(op)) {
str = "and";
} else if (isa<cggi::NandOp>(op)) {
str = "nand";
} else if (isa<cggi::XorOp>(op)) {
str = "xor";
} else if (isa<cggi::XNorOp>(op)) {
str = "xnor";
} else if (isa<cggi::OrOp>(op)) {
str = "or";
} else if (isa<cggi::NorOp>(op)) {
str = "nor";
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Try to parse boolean operation that does not exist.");
}
vectorizedGateOperands.push_back(StringAttr::get(&context, str));
}

// Group the independent operands over the operations
for (uint operandIndex = 0; operandIndex < key->getNumOperands();
++operandIndex) {
SmallVector<Value, 4> operands;
LLVM_DEBUG({
llvm::dbgs() << "For: " << key->getName()
<< " Number of ops: " << key->getNumOperands() << "\n";
});

operands.reserve(bucket.size());
///------------------------------------------
for (auto *op : bucket) {
LLVM_DEBUG(llvm::dbgs() << "getOperand for [" << operandIndex << "]: "
<< op->getOperand(operandIndex) << "\n");
operands.push_back(op->getOperand(operandIndex));
}
///------------------------------------------

auto fromElementsOp = builder.create<tensor::FromElementsOp>(
key->getLoc(), tensorType, operands);
vectorizedOperands.push_back(fromElementsOp.getResult());
}

LLVM_DEBUG({
llvm::dbgs() << "Go over vectorizedOps:\n";
for (auto op : vectorizedOperands) {
llvm::dbgs() << " - " << op << "\n";
}
llvm::dbgs() << "Go over vectorizedGateOps:\n";
for (auto op : vectorizedGateOperands) {
llvm::dbgs() << " - " << op;
}
llvm::dbgs() << "\n";
});

auto oplist = CGGIGateAttr::get(&context, vectorizedGateOperands);

auto vectorizedOp = builder.create<cggi::PackedOp>(
key->getLoc(), tensorType, oplist, vectorizedOperands[0],
vectorizedOperands[1]);

int bucketIndex = 0;

for (auto *op : bucket) {
auto extractionIndex = builder.create<arith::ConstantOp>(
op->getLoc(), builder.getIndexAttr(bucketIndex));
auto extractOp = builder.create<tensor::ExtractOp>(
op->getLoc(), elementType, vectorizedOp->getResult(0),
extractionIndex.getResult());
op->replaceAllUsesWith(ValueRange{extractOp.getResult()});
bucketIndex++;
}

for (auto *op : bucket) {
op->erase();
}
madeReplacement = true;
}
}

return madeReplacement;
}

struct BooleanLineVectorizer
: impl::BooleanLineVectorizerBase<BooleanLineVectorizer> {
using BooleanLineVectorizerBase::BooleanLineVectorizerBase;

void runOnOperation() override {
MLIRContext &context = getContext();

getOperation()->walk<WalkOrder::PreOrder>([&](Block *block) {
if (tryBoolVectorizeBlock(block, context)) {
sortTopologically(block);
}
});
}
};

} // namespace cggi
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/CGGI/Transforms/BooleanLineVectorizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_TRANSFORMS_CGGI_BOOLEANLINEVECTORIZER_H_
#define LIB_TRANSFORMS_CGGI_BOOLEANLINEVECTORIZER_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace cggi {

#define GEN_PASS_DECL_BOOLEANLINEVECTORIZER
#include "lib/Dialect/CGGI/Transforms/Passes.h.inc"

} // namespace cggi
} // namespace heir
} // namespace mlir

#endif // LIB_TRANSFORMS_CGGI_BOOLEANLINEVECTORIZER_H_
1 change: 1 addition & 0 deletions lib/Dialect/CGGI/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIB_DIALECT_CGGI_TRANSFORMS_PASSES_H_

#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/Transforms/BooleanLineVectorizer.h"
#include "lib/Dialect/CGGI/Transforms/SetDefaultParameters.h"

namespace mlir {
Expand Down
Loading

0 comments on commit ac2ef64

Please sign in to comment.