Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ add_library(graphalg_lib STATIC
src/graphalg/GraphAlgSetDimensions.cpp
src/graphalg/GraphAlgSplitAggregate.cpp
src/graphalg/GraphAlgToCore.cpp
src/graphalg/GraphAlgToCorePipeline.cpp
src/graphalg/GraphAlgTypes.cpp
src/graphalg/GraphAlgVerifyDimensions.cpp
src/graphalg/SemiringTypes.cpp
src/graphalg/evaluate/Evaluator.cpp
)
target_include_directories(graphalg_lib PUBLIC include)
target_include_directories(graphalg_lib SYSTEM PUBLIC ${PROJECT_BINARY_DIR}/include)
Expand Down Expand Up @@ -90,15 +92,13 @@ add_library(graphalg_parse STATIC
src/graphalg/parse/Lexer.cpp
src/graphalg/parse/Parser.cpp
)
target_include_directories(graphalg_parse PUBLIC include)
# TODO: Nasty, should be provided by graphalg_lib dep
target_include_directories(graphalg_parse SYSTEM PRIVATE ${PROJECT_BINARY_DIR}/include)
target_link_libraries(graphalg_parse PRIVATE graphalg_lib)

add_executable(graphalg-translate src/graphalg-translate.cpp)
target_link_libraries(graphalg-translate PRIVATE
graphalg_parse
${llvm_libs}
graphalg_lib
graphalg_parse
MLIRTranslateLib
)

Expand All @@ -116,15 +116,21 @@ target_link_libraries(graphalg-lsp-server PRIVATE
MLIRLspServerLib
)

add_executable(graphalg-exec src/graphalg-exec.cpp)
target_link_libraries(graphalg-exec PRIVATE
${llvm_libs}
graphalg_lib
MLIRParser
)

set(ENABLE_WASM OFF CACHE BOOL "Enable wasm-only targets" FORCE)

if(ENABLE_WASM)
add_executable(wasm_parse src/wasm_parse.cpp)
target_link_libraries(wasm_parse PRIVATE
graphalg_parse
${llvm_libs}
graphalg_lib
graphalg_parse
)
target_link_options(wasm_parse PRIVATE
-sEXPORTED_FUNCTIONS=_ga_parse
Expand Down
14 changes: 14 additions & 0 deletions compiler/include/graphalg/GraphAlgAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,18 @@ def DimAttr : GraphAlg_Attr<"Dim", "dim"> {
}];
}

def MatrixAttr : GraphAlg_Attr<"Matrix", "mat"> {
let summary = "Constant value matrix";

let parameters = (ins
AttributeSelfTypeParameter<"">:$type,
"mlir::ArrayAttr":$elems);

let assemblyFormat = [{
`<` $elems `>`
}];

let genVerifyDecl = 1;
}

#endif // GRAPHALG_GRAPH_ALG_ATTR
9 changes: 9 additions & 0 deletions compiler/include/graphalg/GraphAlgPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <llvm/Support/raw_ostream.h>
#include <mlir/IR/Builders.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Pass/PassOptions.h>

#include <graphalg/GraphAlgDialect.h>
#include <graphalg/GraphAlgOps.h>
Expand Down Expand Up @@ -53,6 +55,13 @@ mlir::FailureOr<mlir::Value> createScalarOpFor(mlir::Location loc, BinaryOp op,
#define GEN_PASS_REGISTRATION
#include "graphalg/GraphAlgPasses.h.inc"

struct GraphAlgToCorePipelineOptions
: public mlir::PassPipelineOptions<GraphAlgToCorePipelineOptions> {};

void buildGraphAlgToCorePipeline(mlir::OpPassManager &pm,
const GraphAlgToCorePipelineOptions &options);
void registerGraphAlgToCorePipeline();

// Testing only:
void registerTestDensePass();

Expand Down
74 changes: 74 additions & 0 deletions compiler/include/graphalg/evaluate/Evaluator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include <llvm/ADT/ArrayRef.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>

#include "graphalg/GraphAlgAttr.h"
#include "graphalg/GraphAlgCast.h"
#include "graphalg/GraphAlgTypes.h"

namespace graphalg {

/** Helper for reading elements of \c MatrixAttr. */
class MatrixAttrReader {
private:
MatrixType _type;
std::size_t _rows;
std::size_t _cols;
llvm::ArrayRef<mlir::Attribute> _elems;

public:
MatrixAttrReader(MatrixAttr attr)
: _type(llvm::cast<MatrixType>(attr.getType())),
_rows(_type.getRows().getConcreteDim()),
_cols(_type.getCols().getConcreteDim()),
_elems(attr.getElems().getValue()) {}

std::size_t nRows() const { return _rows; }
std::size_t nCols() const { return _cols; }

SemiringTypeInterface ring() const {
return llvm::cast<SemiringTypeInterface>(_type.getSemiring());
}

mlir::TypedAttr at(std::size_t row, std::size_t col) const {
assert(row < _rows);
assert(col < _cols);
return llvm::cast<mlir::TypedAttr>(_elems[row * _cols + col]);
}
};

class MatrixAttrBuilder {
private:
MatrixType _type;
SemiringTypeInterface _ring;
std::size_t _rows;
std::size_t _cols;
llvm::SmallVector<mlir::Attribute> _elems;

public:
MatrixAttrBuilder(MatrixType type)
: _type(type), _rows(_type.getRows().getConcreteDim()),
_cols(_type.getCols().getConcreteDim()),
_ring(llvm::cast<SemiringTypeInterface>(type.getSemiring())),
_elems(_rows * _cols, _ring.addIdentity()) {}

std::size_t nRows() const { return _rows; }
std::size_t nCols() const { return _cols; }

SemiringTypeInterface ring() const { return _ring; }

void set(std::size_t row, std::size_t col, mlir::TypedAttr attr) {
assert(row < _rows);
assert(col < _cols);
assert(attr.getType() == _ring);
_elems[row * _cols + col] = attr;
}

MatrixAttr build() {
auto *ctx = _type.getContext();
return MatrixAttr::get(ctx, _type, mlir::ArrayAttr::get(ctx, _elems));
}
};

MatrixAttr evaluate(mlir::func::FuncOp funcOp, llvm::ArrayRef<MatrixAttr> args);

} // namespace graphalg
Loading