Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/mlir/cxx/mlir/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cxx/mlir/codegen.h>

// cxx
#include <cxx/ast_interpreter.h>
#include <cxx/control.h>
#include <cxx/external_name_encoder.h>
#include <cxx/symbols.h>
Expand Down Expand Up @@ -79,6 +80,9 @@ auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional<mlir::Value> {
auto var = symbol_cast<VariableSymbol>(symbol);
if (!var) return std::nullopt;

if (var->isStatic()) return std::nullopt;
if (!var->parent()->isBlock()) return std::nullopt;

auto type = convertType(var->type());
auto ptrType = builder_.getType<mlir::cxx::PointerType>(type);

Expand Down Expand Up @@ -172,6 +176,73 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol)
return func;
}

auto Codegen::findOrCreateGlobal(VariableSymbol* variableSymbol)
-> mlir::cxx::GlobalOp {
if (auto it = globalOps_.find(variableSymbol); it != globalOps_.end()) {
return it->second;
}

auto varType = convertType(variableSymbol->type());

const auto loc = getLocation(variableSymbol->location());

auto guard = mlir::OpBuilder::InsertionGuard(builder_);

builder_.setInsertionPointToStart(module_.getBody());

mlir::cxx::InlineKind inlineKind = mlir::cxx::InlineKind::NoInline;

mlir::cxx::LinkageKind linkageKind = mlir::cxx::LinkageKind::External;

if (variableSymbol->isStatic()) {
linkageKind = mlir::cxx::LinkageKind::Internal;
}

auto linkageAttr =
mlir::cxx::LinkageKindAttr::get(builder_.getContext(), linkageKind);

std::string name;

name = to_string(variableSymbol->name());

llvm::SmallVector<mlir::Type> resultTypes;
resultTypes.push_back(varType);

mlir::Attribute initializer;

auto value = variableSymbol->constValue();

if (value.has_value()) {
auto interp = ASTInterpreter{unit_};

if (control()->is_integral_or_unscoped_enum(variableSymbol->type())) {
auto constValue = interp.toInt(*value);
initializer = builder_.getI64IntegerAttr(constValue.value_or(0));
} else if (control()->is_floating_point(variableSymbol->type())) {
auto ty = control()->remove_cv(variableSymbol->type());
if (type_cast<FloatType>(ty)) {
auto constValue = interp.toFloat(*value);
initializer = builder_.getF32FloatAttr(constValue.value_or(0));
} else if (type_cast<DoubleType>(ty)) {
auto constValue = interp.toDouble(*value);
initializer = builder_.getF64FloatAttr(constValue.value_or(0));
}
}
}

if (!initializer) {
// default initialize to zero
initializer = builder_.getZeroAttr(varType);
}

auto var = mlir::cxx::GlobalOp::create(builder_, loc, varType, false, name,
initializer);

globalOps_.insert_or_assign(variableSymbol, var);

return var;
}

auto Codegen::getLocation(SourceLocation location) -> mlir::Location {
auto [filename, line, column] = unit_->tokenStartPosition(location);

Expand Down
4 changes: 4 additions & 0 deletions src/mlir/cxx/mlir/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ class Codegen {
[[nodiscard]] auto findOrCreateFunction(FunctionSymbol* functionSymbol)
-> mlir::cxx::FuncOp;

[[nodiscard]] auto findOrCreateGlobal(VariableSymbol* var)
-> mlir::cxx::GlobalOp;

[[nodiscard]] auto newTemp(const Type* type, SourceLocation loc)
-> mlir::cxx::AllocaOp;

Expand Down Expand Up @@ -328,6 +331,7 @@ class Codegen {
std::unordered_map<ClassSymbol*, mlir::Type> classNames_;
std::unordered_map<Symbol*, mlir::Value> locals_;
std::unordered_map<FunctionSymbol*, mlir::cxx::FuncOp> funcOps_;
std::unordered_map<VariableSymbol*, mlir::cxx::GlobalOp> globalOps_;
std::unordered_map<std::string_view, int> uniqueSymbolNames_;
std::unordered_map<const StringLiteral*, mlir::StringAttr> stringLiterals_;
Loop loop_;
Expand Down
14 changes: 14 additions & 0 deletions src/mlir/cxx/mlir/codegen_declarations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast)
-> DeclarationResult {
if (!gen.function_) {
// skip for now, as we only look for local variable declarations

for (auto node : ListView{ast->initDeclaratorList}) {
auto var = symbol_cast<VariableSymbol>(node->symbol);
if (!var) continue;

auto glo = gen.findOrCreateGlobal(var);
if (!glo) {
gen.unit_->error(node->initializer->firstSourceLocation(),
std::format("cannot create global variable '{}'",
to_string(var->name())));
continue;
}
}

return {};
}

Expand Down
23 changes: 18 additions & 5 deletions src/mlir/cxx/mlir/codegen_expressions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,23 @@ auto Codegen::ExpressionVisitor::operator()(IdExpressionAST* ast)
}
}

if (auto var = symbol_cast<VariableSymbol>(ast->symbol)) {
if (auto it = gen.globalOps_.find(var); it != gen.globalOps_.end()) {
auto loc = gen.getLocation(ast->firstSourceLocation());

auto ptrToVoidType = mlir::cxx::PointerType::get(
gen.builder_.getContext(), gen.convertType(var->type()));

auto resultType =
mlir::cxx::PointerType::get(gen.builder_.getContext(), ptrToVoidType);

auto op = mlir::cxx::AddressOfOp::create(gen.builder_, loc, resultType,
it->second.getSymName());

return {op};
}
}

auto op =
gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind()));

Expand Down Expand Up @@ -694,9 +711,7 @@ auto Codegen::ExpressionVisitor::operator()(PostIncrExpressionAST* ast)

if (control()->is_integral_or_unscoped_enum(ast->baseExpression->type)) {
auto loc = gen.getLocation(ast->firstSourceLocation());
auto ptrTy =
mlir::cast<mlir::cxx::PointerType>(expressionResult.value.getType());
auto elementTy = ptrTy.getElementType();
auto elementTy = gen.convertType(ast->baseExpression->type);
auto loadOp = mlir::cxx::LoadOp::create(gen.builder_, loc, elementTy,
expressionResult.value);
auto resultTy = gen.convertType(ast->baseExpression->type);
Expand Down Expand Up @@ -944,8 +959,6 @@ auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast)
}

if (control()->is_floating_point(ast->type)) {
resultType.dump();

mlir::FloatAttr value;
switch (ast->type->kind()) {
case TypeKind::kFloat:
Expand Down
5 changes: 5 additions & 0 deletions src/mlir/cxx/mlir/codegen_units.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ struct Codegen::UnitVisitor {

void operator()(NamespaceSymbol* symbol) {
for (auto member : views::members(symbol)) {
if (auto var = symbol_cast<VariableSymbol>(member)) {
p.gen.findOrCreateGlobal(var);
continue;
}

visit(*this, member);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/mlir/cxx/mlir/cxx_dialect_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class GlobalOpLowering : public OpConversionPattern<cxx::GlobalOp> {

rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
op, elementType, op.getConstant(), LLVM::linkage::Linkage::Private,
op.getSymName(), adaptor.getValue().value());
op.getSymName(), adaptor.getValueAttr());

return success();
}
Expand Down
4 changes: 4 additions & 0 deletions src/parser/cxx/lexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ auto Lexer::readToken() -> TokenKind {
(LA(1) == '+' || LA(1) == '-')) {
consume(2);
integer_literal = false;
} else if (pos_ + 1 < end_ && (ch == 'e' || ch == 'E') &&
std::isdigit(LA(1))) {
consume(1);
integer_literal = false;
} else if (pos_ + 1 < end_ && ch == '\'' && is_idcont(LA(1))) {
consume();
} else if (is_idcont(ch)) {
Expand Down
15 changes: 5 additions & 10 deletions src/parser/cxx/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5659,17 +5659,12 @@ auto Parser::parse_init_declarator(InitDeclaratorAST*& yyast,
control()->remove_cv(var->type()));
}

if (var->isConstexpr()) {
if (!var->initializer()) {
parse_error(var->location(), "constexpr variable must be initialized");
} else {
var->setConstValue(evaluate_constant_expression(var->initializer()));
if (var->initializer()) {
var->setConstValue(evaluate_constant_expression(var->initializer()));
}

if (!var->constValue().has_value()) {
type_error(var->location(),
"initializer of constexpr variable is not a constant");
}
}
if (var->isConstexpr() && !var->constValue().has_value()) {
parse_error(var->location(), "constexpr variable must be initialized");
}
}

Expand Down
Loading