diff --git a/src/mlir/cxx/mlir/codegen.cc b/src/mlir/cxx/mlir/codegen.cc index b9ec63b7..b9a29d0b 100644 --- a/src/mlir/cxx/mlir/codegen.cc +++ b/src/mlir/cxx/mlir/codegen.cc @@ -21,6 +21,7 @@ #include // cxx +#include #include #include #include @@ -79,6 +80,9 @@ auto Codegen::findOrCreateLocal(Symbol* symbol) -> std::optional { auto var = symbol_cast(symbol); if (!var) return std::nullopt; + if (var->isStatic()) return std::nullopt; + if (!var->parent()->isBlock()) return std::nullopt; + auto type = convertType(var->type()); auto ptrType = builder_.getType(type); @@ -172,6 +176,73 @@ auto Codegen::findOrCreateFunction(FunctionSymbol* functionSymbol) return func; } +auto Codegen::findOrCreateGlobal(VariableSymbol* variableSymbol) + -> mlir::cxx::GlobalOp { + if (auto it = globalOps_.find(variableSymbol); it != globalOps_.end()) { + return it->second; + } + + auto varType = convertType(variableSymbol->type()); + + const auto loc = getLocation(variableSymbol->location()); + + auto guard = mlir::OpBuilder::InsertionGuard(builder_); + + builder_.setInsertionPointToStart(module_.getBody()); + + mlir::cxx::InlineKind inlineKind = mlir::cxx::InlineKind::NoInline; + + mlir::cxx::LinkageKind linkageKind = mlir::cxx::LinkageKind::External; + + if (variableSymbol->isStatic()) { + linkageKind = mlir::cxx::LinkageKind::Internal; + } + + auto linkageAttr = + mlir::cxx::LinkageKindAttr::get(builder_.getContext(), linkageKind); + + std::string name; + + name = to_string(variableSymbol->name()); + + llvm::SmallVector resultTypes; + resultTypes.push_back(varType); + + mlir::Attribute initializer; + + auto value = variableSymbol->constValue(); + + if (value.has_value()) { + auto interp = ASTInterpreter{unit_}; + + if (control()->is_integral_or_unscoped_enum(variableSymbol->type())) { + auto constValue = interp.toInt(*value); + initializer = builder_.getI64IntegerAttr(constValue.value_or(0)); + } else if (control()->is_floating_point(variableSymbol->type())) { + auto ty = control()->remove_cv(variableSymbol->type()); + if (type_cast(ty)) { + auto constValue = interp.toFloat(*value); + initializer = builder_.getF32FloatAttr(constValue.value_or(0)); + } else if (type_cast(ty)) { + auto constValue = interp.toDouble(*value); + initializer = builder_.getF64FloatAttr(constValue.value_or(0)); + } + } + } + + if (!initializer) { + // default initialize to zero + initializer = builder_.getZeroAttr(varType); + } + + auto var = mlir::cxx::GlobalOp::create(builder_, loc, varType, false, name, + initializer); + + globalOps_.insert_or_assign(variableSymbol, var); + + return var; +} + auto Codegen::getLocation(SourceLocation location) -> mlir::Location { auto [filename, line, column] = unit_->tokenStartPosition(location); diff --git a/src/mlir/cxx/mlir/codegen.h b/src/mlir/cxx/mlir/codegen.h index e68f2d2f..5ccbfb36 100644 --- a/src/mlir/cxx/mlir/codegen.h +++ b/src/mlir/cxx/mlir/codegen.h @@ -262,6 +262,9 @@ class Codegen { [[nodiscard]] auto findOrCreateFunction(FunctionSymbol* functionSymbol) -> mlir::cxx::FuncOp; + [[nodiscard]] auto findOrCreateGlobal(VariableSymbol* var) + -> mlir::cxx::GlobalOp; + [[nodiscard]] auto newTemp(const Type* type, SourceLocation loc) -> mlir::cxx::AllocaOp; @@ -328,6 +331,7 @@ class Codegen { std::unordered_map classNames_; std::unordered_map locals_; std::unordered_map funcOps_; + std::unordered_map globalOps_; std::unordered_map uniqueSymbolNames_; std::unordered_map stringLiterals_; Loop loop_; diff --git a/src/mlir/cxx/mlir/codegen_declarations.cc b/src/mlir/cxx/mlir/codegen_declarations.cc index b092ba63..d8a8d303 100644 --- a/src/mlir/cxx/mlir/codegen_declarations.cc +++ b/src/mlir/cxx/mlir/codegen_declarations.cc @@ -182,6 +182,20 @@ auto Codegen::DeclarationVisitor::operator()(SimpleDeclarationAST* ast) -> DeclarationResult { if (!gen.function_) { // skip for now, as we only look for local variable declarations + + for (auto node : ListView{ast->initDeclaratorList}) { + auto var = symbol_cast(node->symbol); + if (!var) continue; + + auto glo = gen.findOrCreateGlobal(var); + if (!glo) { + gen.unit_->error(node->initializer->firstSourceLocation(), + std::format("cannot create global variable '{}'", + to_string(var->name()))); + continue; + } + } + return {}; } diff --git a/src/mlir/cxx/mlir/codegen_expressions.cc b/src/mlir/cxx/mlir/codegen_expressions.cc index 484181da..b6a15c02 100644 --- a/src/mlir/cxx/mlir/codegen_expressions.cc +++ b/src/mlir/cxx/mlir/codegen_expressions.cc @@ -376,6 +376,23 @@ auto Codegen::ExpressionVisitor::operator()(IdExpressionAST* ast) } } + if (auto var = symbol_cast(ast->symbol)) { + if (auto it = gen.globalOps_.find(var); it != gen.globalOps_.end()) { + auto loc = gen.getLocation(ast->firstSourceLocation()); + + auto ptrToVoidType = mlir::cxx::PointerType::get( + gen.builder_.getContext(), gen.convertType(var->type())); + + auto resultType = + mlir::cxx::PointerType::get(gen.builder_.getContext(), ptrToVoidType); + + auto op = mlir::cxx::AddressOfOp::create(gen.builder_, loc, resultType, + it->second.getSymName()); + + return {op}; + } + } + auto op = gen.emitTodoExpr(ast->firstSourceLocation(), to_string(ast->kind())); @@ -694,9 +711,7 @@ auto Codegen::ExpressionVisitor::operator()(PostIncrExpressionAST* ast) if (control()->is_integral_or_unscoped_enum(ast->baseExpression->type)) { auto loc = gen.getLocation(ast->firstSourceLocation()); - auto ptrTy = - mlir::cast(expressionResult.value.getType()); - auto elementTy = ptrTy.getElementType(); + auto elementTy = gen.convertType(ast->baseExpression->type); auto loadOp = mlir::cxx::LoadOp::create(gen.builder_, loc, elementTy, expressionResult.value); auto resultTy = gen.convertType(ast->baseExpression->type); @@ -944,8 +959,6 @@ auto Codegen::ExpressionVisitor::operator()(UnaryExpressionAST* ast) } if (control()->is_floating_point(ast->type)) { - resultType.dump(); - mlir::FloatAttr value; switch (ast->type->kind()) { case TypeKind::kFloat: diff --git a/src/mlir/cxx/mlir/codegen_units.cc b/src/mlir/cxx/mlir/codegen_units.cc index c3488312..5de5bb6d 100644 --- a/src/mlir/cxx/mlir/codegen_units.cc +++ b/src/mlir/cxx/mlir/codegen_units.cc @@ -70,6 +70,11 @@ struct Codegen::UnitVisitor { void operator()(NamespaceSymbol* symbol) { for (auto member : views::members(symbol)) { + if (auto var = symbol_cast(member)) { + p.gen.findOrCreateGlobal(var); + continue; + } + visit(*this, member); } } diff --git a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc index a74753fe..e2d4e899 100644 --- a/src/mlir/cxx/mlir/cxx_dialect_conversions.cc +++ b/src/mlir/cxx/mlir/cxx_dialect_conversions.cc @@ -130,7 +130,7 @@ class GlobalOpLowering : public OpConversionPattern { rewriter.replaceOpWithNewOp( op, elementType, op.getConstant(), LLVM::linkage::Linkage::Private, - op.getSymName(), adaptor.getValue().value()); + op.getSymName(), adaptor.getValueAttr()); return success(); } diff --git a/src/parser/cxx/lexer.cc b/src/parser/cxx/lexer.cc index f09d9b2b..cefaf7ad 100644 --- a/src/parser/cxx/lexer.cc +++ b/src/parser/cxx/lexer.cc @@ -159,6 +159,10 @@ auto Lexer::readToken() -> TokenKind { (LA(1) == '+' || LA(1) == '-')) { consume(2); integer_literal = false; + } else if (pos_ + 1 < end_ && (ch == 'e' || ch == 'E') && + std::isdigit(LA(1))) { + consume(1); + integer_literal = false; } else if (pos_ + 1 < end_ && ch == '\'' && is_idcont(LA(1))) { consume(); } else if (is_idcont(ch)) { diff --git a/src/parser/cxx/parser.cc b/src/parser/cxx/parser.cc index 1b910c38..e6e936d7 100644 --- a/src/parser/cxx/parser.cc +++ b/src/parser/cxx/parser.cc @@ -5659,17 +5659,12 @@ auto Parser::parse_init_declarator(InitDeclaratorAST*& yyast, control()->remove_cv(var->type())); } - if (var->isConstexpr()) { - if (!var->initializer()) { - parse_error(var->location(), "constexpr variable must be initialized"); - } else { - var->setConstValue(evaluate_constant_expression(var->initializer())); + if (var->initializer()) { + var->setConstValue(evaluate_constant_expression(var->initializer())); + } - if (!var->constValue().has_value()) { - type_error(var->location(), - "initializer of constexpr variable is not a constant"); - } - } + if (var->isConstexpr() && !var->constValue().has_value()) { + parse_error(var->location(), "constexpr variable must be initialized"); } }