diff --git a/.gitignore b/.gitignore index c4ed460..d3af9fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /node_modules +/thirdparty # ===== BEGIN JEKYLL ===== diff --git a/compiler/src/graphalg/parse/Parser.cpp b/compiler/src/graphalg/parse/Parser.cpp index f85d778..052e895 100644 --- a/compiler/src/graphalg/parse/Parser.cpp +++ b/compiler/src/graphalg/parse/Parser.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include "graphalg/SemiringTypes.h" #include "graphalg/parse/Lexer.h" #include "graphalg/parse/Parser.h" +#include "llvm/ADT/StringMap.h" namespace graphalg { @@ -88,6 +90,9 @@ class Parser { DimMapper _dimMapper; + // Expected return type for the current function. + mlir::Type _expectedReturnType; + Token cur() { return _tokens[_offset]; } void eat() { @@ -172,9 +177,53 @@ class Parser { mlir::ParseResult parseBinaryOp(BinaryOp &op); + /** + * Check input types and build \c MatMulOp. + * + * @return \c nullptr if type checking fails. + */ mlir::Value buildMatMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs); + /** + * Check input types and build \c ElementWiseOp. + * + * @param mustBeScalar Whether we parsed (a b) or (a (.) b). + * + * @return \c nullptr if type checking fails. + */ + mlir::Value buildElementWise(mlir::Location loc, mlir::Value lhs, BinaryOp op, + mlir::Value rhs, bool mustBeScalar); + + /** + * Check input types and build \c ElementWiseApplyOp. + * + * @return \c nullptr if type checking fails. + */ + mlir::Value buildElementWiseApply(mlir::Location loc, mlir::Value lhs, + mlir::func::FuncOp funcOp, mlir::Value rhs); + + /** + * Builds \c TransposeOp, \c NrowsOp, \c NcolsOp or \c NvalsOp depending on \c + * property. + */ + mlir::Value buildDotProperty(mlir::Location loc, mlir::Value value, + llvm::StringRef property); + mlir::ParseResult parseAtom(mlir::Value &v); + mlir::ParseResult parseAtomMatrix(mlir::Value &v); + mlir::ParseResult parseAtomVector(mlir::Value &v); + mlir::ParseResult parseAtomCast(mlir::Value &v); + mlir::ParseResult parseAtomZero(mlir::Value &v); + mlir::ParseResult parseAtomOne(mlir::Value &v); + mlir::ParseResult parseAtomApply(mlir::Value &v); + mlir::ParseResult parseAtomSelect(mlir::Value &v); + mlir::ParseResult parseAtomReduceRows(mlir::Value &v); + mlir::ParseResult parseAtomReduceCols(mlir::Value &v); + mlir::ParseResult parseAtomReduce(mlir::Value &v); + mlir::ParseResult parseAtomPickAny(mlir::Value &v); + mlir::ParseResult parseAtomDiag(mlir::Value &v); + mlir::ParseResult parseAtomTril(mlir::Value &v); + mlir::ParseResult parseAtomTriu(mlir::Value &v); mlir::ParseResult parseLiteral(mlir::Type ring, mlir::Value &v); @@ -254,7 +303,9 @@ void TypeFormatter::formatColumnVector(MatrixType t) { } void TypeFormatter::formatMatrix(MatrixType t) { - if (t.isColumnVector()) { + if (t.isScalar()) { + return formatScalar(t.getSemiring()); + } else if (t.isColumnVector()) { return formatColumnVector(t); } @@ -439,6 +490,11 @@ mlir::ParseResult Parser::parseProgram() { return mlir::success(); } +static bool hasReturn(mlir::Block &block) { + return block.mightHaveTerminator() && + llvm::isa(block.getTerminator()); +} + mlir::ParseResult Parser::parseFunction() { llvm::StringRef name; llvm::SmallVector paramNames; @@ -452,7 +508,12 @@ mlir::ParseResult Parser::parseFunction() { return mlir::failure(); } - // TODO: Check duplicate definition + if (auto *previousDef = _module.lookupSymbol(name)) { + auto diag = mlir::emitError(loc) + << "duplicate definition of function '" << name << "'"; + diag.attachNote(previousDef->getLoc()) << "original definition here"; + return diag; + } // Create the new op. auto funcType = _builder.getFunctionType(paramTypes, {returnType}); @@ -471,11 +532,17 @@ mlir::ParseResult Parser::parseFunction() { } } + // Set expected return type for this function + _expectedReturnType = returnType; + if (mlir::failed(parseBlock())) { return mlir::failure(); } - // TODO: Check for return statement. + // Check for return statement + if (!hasReturn(entryBlock)) { + return mlir::emitError(loc) << "function must have a return statement"; + } return mlir::success(); } @@ -496,11 +563,15 @@ Parser::parseParams(llvm::SmallVectorImpl &names, // First parameter auto &name = names.emplace_back(); auto &type = types.emplace_back(); - locs.emplace_back(cur().loc); + auto loc = cur().loc; + locs.emplace_back(loc); if (parseIdent(name) || eatOrError(Token::COLON) || parseType(type)) { return mlir::failure(); } + llvm::SmallDenseMap previousParams = { + {name, loc}}; + while (cur().type != Token::RPAREN) { // More parameters if (eatOrError(Token::COMMA)) { @@ -509,10 +580,21 @@ Parser::parseParams(llvm::SmallVectorImpl &names, auto &name = names.emplace_back(); auto &type = types.emplace_back(); - locs.emplace_back(cur().loc); + auto loc = cur().loc; + locs.emplace_back(loc); if (parseIdent(name) || eatOrError(Token::COLON) || parseType(type)) { return mlir::failure(); } + + // Check for duplicate parameter names. + if (previousParams.contains(name)) { + auto diag = mlir::emitError(loc) + << "duplicate parameter name '" << name << "'"; + diag.attachNote(previousParams.at(name)) << "previous definition here"; + return diag; + } + + previousParams.insert({name, loc}); } return eatOrError(Token::RPAREN); @@ -527,6 +609,13 @@ mlir::ParseResult Parser::parseBlock() { if (parseStmt()) { return mlir::failure(); } + + // Check if there are statements after a return + if (hasReturn(*_builder.getInsertionBlock()) && + cur().type != Token::RBRACE) { + return mlir::emitError(cur().loc) + << "statement after return is not allowed"; + } } if (eatOrError(Token::RBRACE)) { @@ -702,7 +791,7 @@ mlir::ParseResult Parser::parseStmtFor() { } } - // Parse condtion expression. + // Parse condition expression. mlir::Value result; loc = cur().loc; if (parseExpr(result) || eatOrError(Token::SEMI)) { @@ -733,12 +822,33 @@ mlir::ParseResult Parser::parseStmtFor() { mlir::ParseResult Parser::parseStmtReturn() { auto loc = cur().loc; + + // Check if return is inside a loop + auto *parentOp = _builder.getInsertionBlock()->getParentOp(); + if (llvm::isa(parentOp)) { + return mlir::emitError(loc) + << "return statement inside a loop is not allowed"; + } + + // The parser should not create any other nested ops apart from for loops, so + // we should be at the top-level of a function scope. + assert(llvm::isa(parentOp) && + "return outside of function body"); + mlir::Value returnValue; if (eatOrError(Token::RETURN) || parseExpr(returnValue) || eatOrError(Token::SEMI)) { return mlir::failure(); } + // Check return type matches + if (returnValue.getType() != _expectedReturnType) { + return mlir::emitError(loc) + << "return type mismatch: expected " + << typeToString(_expectedReturnType) << ", but got " + << typeToString(returnValue.getType()); + } + _builder.create(loc, returnValue); return mlir::success(); } @@ -923,6 +1033,11 @@ mlir::Value Parser::applyFill(mlir::Location baseLoc, mlir::Value base, << "vector fill [:] used with non-vector base"; diag.attachNote(baseLoc) << "base has type " << typeToString(baseType); return nullptr; + } else if (fill == ParsedFill::MATRIX && baseType.isColumnVector()) { + auto diag = mlir::emitError(fillLoc) + << "matrix fill [:, :] used with column vector base"; + diag.attachNote(baseLoc) << "base has type " << typeToString(baseType); + return nullptr; } return _builder.create(fillLoc, baseType, expr); @@ -942,7 +1057,7 @@ mlir::ParseResult Parser::parseStmtAccum(mlir::Location baseLoc, return mlir::emitError(baseLoc) << "type of base does not match the expression to accumulate: (" << typeToString(baseValue.getType()) << " vs. " - << typeToString(expr.getType()); + << typeToString(expr.getType()) << ")"; } // Rewrite a += b; to a = a (.+) b; @@ -960,11 +1075,30 @@ mlir::ParseResult Parser::parseRange(ParsedRange &r) { if (cur().type == Token::COLON) { // Const range + auto beginLoc = exprLoc; r.begin = expr; + + auto endLoc = cur().loc; if (eatOrError(Token::COLON) || parseExpr(r.end)) { return mlir::failure(); } + // Check that begin is an integer scalar + auto intScalarType = + MatrixType::scalarOf(SemiringTypes::forInt(_builder.getContext())); + if (r.begin.getType() != intScalarType) { + return mlir::emitError(beginLoc) + << "loop range start must be an integer, but got " + << typeToString(r.begin.getType()); + } + + // Check that end is an integer scalar + if (r.end.getType() != intScalarType) { + return mlir::emitError(endLoc) + << "loop range end must be an integer, but got " + << typeToString(r.end.getType()); + } + return mlir::success(); } else { r.dim = inferDim(expr, exprLoc); @@ -1023,21 +1157,14 @@ mlir::ParseResult Parser::parseExpr(mlir::Value &v, int minPrec) { << "expected matrix property such as 'nrows'"; } - if (cur().body == "T") { - atomLhs = _builder.create(cur().loc, atomLhs); - } else if (cur().body == "nrows") { - auto matType = llvm::cast(atomLhs.getType()); - atomLhs = _builder.create(cur().loc, matType.getRows()); - } else if (cur().body == "ncols") { - auto matType = llvm::cast(atomLhs.getType()); - atomLhs = _builder.create(cur().loc, matType.getCols()); - } else if (cur().body == "nvals") { - atomLhs = _builder.create(cur().loc, atomLhs); - } else { - return mlir::emitError(cur().loc) << "invalid matrix property"; - } - + auto loc = cur().loc; + auto property = cur().body; eat(); + + atomLhs = buildDotProperty(loc, atomLhs, property); + if (!atomLhs) { + return mlir::failure(); + } } else if (ewise) { eat(); // '(' eat(); // '.' @@ -1051,8 +1178,10 @@ mlir::ParseResult Parser::parseExpr(mlir::Value &v, int minPrec) { return mlir::failure(); } - atomLhs = - _builder.create(loc, funcOp, atomLhs, atomRhs); + atomLhs = buildElementWiseApply(loc, atomLhs, funcOp, atomRhs); + if (!atomLhs) { + return mlir::failure(); + } } else { // element-wise binop auto loc = cur().loc; @@ -1063,7 +1192,11 @@ mlir::ParseResult Parser::parseExpr(mlir::Value &v, int minPrec) { return mlir::failure(); } - atomLhs = _builder.create(loc, atomLhs, binop, atomRhs); + atomLhs = buildElementWise(loc, atomLhs, binop, atomRhs, + /*mustBeScalar=*/false); + if (!atomLhs) { + return mlir::failure(); + } } } else { // Binary operator @@ -1077,9 +1210,15 @@ mlir::ParseResult Parser::parseExpr(mlir::Value &v, int minPrec) { if (binop == BinaryOp::MUL) { // Matmul special case atomLhs = buildMatMul(loc, atomLhs, atomRhs); + if (!atomLhs) { + return mlir::failure(); + } } else { - // TODO: check scalar matrix types - atomLhs = _builder.create(loc, atomLhs, binop, atomRhs); + atomLhs = buildElementWise(loc, atomLhs, binop, atomRhs, + /*mustBeScalar=*/true); + if (!atomLhs) { + return mlir::failure(); + } } } } @@ -1159,6 +1298,174 @@ mlir::Value Parser::buildMatMul(mlir::Location loc, mlir::Value lhs, return nullptr; } +mlir::Value Parser::buildElementWise(mlir::Location loc, mlir::Value lhs, + BinaryOp op, mlir::Value rhs, + bool mustBeScalar) { + auto lhsType = llvm::cast(lhs.getType()); + auto rhsType = llvm::cast(rhs.getType()); + + // Check that semirings match + if (lhsType.getSemiring() != rhsType.getSemiring()) { + auto diag = mlir::emitError(loc) << "operands have different semirings"; + diag.attachNote(lhs.getLoc()) + << "left operand has semiring " << typeToString(lhsType.getSemiring()); + diag.attachNote(rhs.getLoc()) + << "right operand has semiring " << typeToString(rhsType.getSemiring()); + return nullptr; + } + + if (mustBeScalar) { + // Syntax a + b instead of a (.+) b, so require operands to be scalars. + if (!lhsType.isScalar() || !rhsType.isScalar()) { + auto diag = mlir::emitError(loc) + << "operands are not scalar. Did you mean to use " + "element-wise (a (.f) b) syntax?"; + diag.attachNote(lhs.getLoc()) + << "left operand has dimensions " << dimsToString(lhsType.getDims()); + diag.attachNote(rhs.getLoc()) + << "right operand has dimensions " << dimsToString(rhsType.getDims()); + return nullptr; + } + } else { + // Dimensions must match + if (lhsType.getDims() != rhsType.getDims()) { + auto diag = mlir::emitError(loc) << "operands have different dimensions"; + diag.attachNote(lhs.getLoc()) + << "left operand has dimensions " << dimsToString(lhsType.getDims()); + diag.attachNote(rhs.getLoc()) + << "right operand has dimensions " << dimsToString(rhsType.getDims()); + return nullptr; + } + } + + // Additional validation for specific operators + if (op == BinaryOp::SUB) { + // Subtraction only supports int and real semirings + auto *ctx = _builder.getContext(); + auto semiring = lhsType.getSemiring(); + if (semiring != SemiringTypes::forInt(ctx) && + semiring != SemiringTypes::forReal(ctx)) { + auto diag = mlir::emitError(loc) + << "subtraction is only supported for int and real semirings"; + diag.attachNote(lhs.getLoc()) + << "operands have semiring " << typeToString(semiring); + return nullptr; + } + } else if (op == BinaryOp::DIV) { + // Division only supports real semiring + auto *ctx = _builder.getContext(); + auto semiring = lhsType.getSemiring(); + if (semiring != SemiringTypes::forReal(ctx)) { + auto diag = mlir::emitError(loc) + << "division is only supported for real semiring"; + diag.attachNote(lhs.getLoc()) + << "operands have semiring " << typeToString(semiring); + return nullptr; + } + } else if (op == BinaryOp::LT || op == BinaryOp::GT || op == BinaryOp::LE || + op == BinaryOp::GE) { + // Ordered compare only supports int, real semirings + auto *ctx = _builder.getContext(); + auto semiring = lhsType.getSemiring(); + if (semiring != SemiringTypes::forInt(ctx) && + semiring != SemiringTypes::forReal(ctx)) { + auto diag = mlir::emitError(loc) << "ordered compare is only supported " + "for int and real semirings"; + diag.attachNote(lhs.getLoc()) + << "operands have semiring " << typeToString(semiring); + return nullptr; + } + } + + return _builder.create(loc, lhs, op, rhs); +} + +mlir::Value Parser::buildElementWiseApply(mlir::Location loc, mlir::Value lhs, + mlir::func::FuncOp funcOp, + mlir::Value rhs) { + // Validate element-wise function application + auto funcType = funcOp.getFunctionType(); + + // Check that function takes exactly 2 parameters + if (funcType.getNumInputs() != 2) { + auto diag = mlir::emitError(loc) + << "element-wise function application requires a function " + "with 2 parameters, but got " + << funcType.getNumInputs(); + diag.attachNote(funcOp.getLoc()) << "function defined here"; + return nullptr; + } + + auto lhsType = llvm::cast(lhs.getType()); + auto rhsType = llvm::cast(rhs.getType()); + auto param0Type = llvm::cast(funcType.getInput(0)); + auto param1Type = llvm::cast(funcType.getInput(1)); + + // Check that function parameters are scalars. + if (!param0Type.isScalar() || !param1Type.isScalar()) { + auto diag = mlir::emitError(loc) + << "element-wise function application requires function " + "parameters to be scalars"; + diag.attachNote(funcOp.getLoc()) + << "first parameter has type " << typeToString(param0Type); + diag.attachNote(funcOp.getLoc()) + << "second parameter has type " << typeToString(param1Type); + return nullptr; + } + + // Check that operand dimensions match + if (lhsType.getDims() != rhsType.getDims()) { + auto diag = mlir::emitError(loc) << "operands have different dimensions"; + diag.attachNote(lhs.getLoc()) + << "left operand has dimensions " << dimsToString(lhsType.getDims()); + diag.attachNote(rhs.getLoc()) + << "right operand has dimensions " << dimsToString(rhsType.getDims()); + return nullptr; + } + + // Check that operand semirings match function parameter semirings + if (lhsType.getSemiring() != param0Type.getSemiring()) { + auto diag = mlir::emitError(loc) + << "left operand semiring does not match first parameter type"; + diag.attachNote(lhs.getLoc()) + << "left operand has semiring " << typeToString(lhsType.getSemiring()); + diag.attachNote(funcOp.getLoc()) << "first parameter has semiring " + << typeToString(param0Type.getSemiring()); + return nullptr; + } + + if (rhsType.getSemiring() != param1Type.getSemiring()) { + auto diag = + mlir::emitError(loc) + << "right operand semiring does not match second parameter type"; + diag.attachNote(rhs.getLoc()) + << "right operand has semiring " << typeToString(rhsType.getSemiring()); + diag.attachNote(funcOp.getLoc()) << "second parameter has semiring " + << typeToString(param1Type.getSemiring()); + return nullptr; + } + + return _builder.create(loc, funcOp, lhs, rhs); +} + +mlir::Value Parser::buildDotProperty(mlir::Location loc, mlir::Value value, + llvm::StringRef property) { + if (property == "T") { + return _builder.create(loc, value); + } else if (property == "nrows") { + auto matType = llvm::cast(value.getType()); + return _builder.create(loc, matType.getRows()); + } else if (property == "ncols") { + auto matType = llvm::cast(value.getType()); + return _builder.create(loc, matType.getCols()); + } else if (property == "nvals") { + return _builder.create(loc, value); + } else { + mlir::emitError(loc) << "invalid matrix property"; + return nullptr; + } +} + mlir::ParseResult Parser::parseAtom(mlir::Value &v) { auto loc = cur().loc; switch (cur().type) { @@ -1189,250 +1496,61 @@ mlir::ParseResult Parser::parseAtom(mlir::Value &v) { } if (name == "Matrix") { - mlir::Type ring; - mlir::Value rowsExpr; - mlir::Value colsExpr; - if (eatOrError(Token::LANGLE) || parseSemiring(ring) || - eatOrError(Token::RANGLE) || eatOrError(Token::LPAREN) || - parseExpr(rowsExpr) || eatOrError(Token::COMMA) || - parseExpr(colsExpr) || eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - // TODO: Better ref locs - auto rows = inferDim(rowsExpr, loc); - auto cols = inferDim(colsExpr, loc); - if (!rows || !cols) { - return mlir::failure(); - } - - v = _builder.create( - loc, _builder.getType(rows, cols, ring), - llvm::cast(ring).addIdentity()); - return mlir::success(); + return parseAtomMatrix(v); } if (name == "Vector") { - mlir::Type ring; - mlir::Value rowsExpr; - if (eatOrError(Token::LANGLE) || parseSemiring(ring) || - eatOrError(Token::RANGLE) || eatOrError(Token::LPAREN) || - parseExpr(rowsExpr) || eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - // TODO: Better ref locs - auto rows = inferDim(rowsExpr, loc); - if (!rows) { - return mlir::failure(); - } - - auto *ctx = _builder.getContext(); - v = _builder.create( - loc, MatrixType::get(ctx, rows, DimAttr::getOne(ctx), ring), - llvm::cast(ring).addIdentity()); - return mlir::success(); + return parseAtomVector(v); } if (name == "cast") { - mlir::Type ring; - mlir::Value expr; - if (eatOrError(Token::LANGLE) || parseSemiring(ring) || - eatOrError(Token::RANGLE) || eatOrError(Token::LPAREN) || - parseExpr(expr) || eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - auto exprType = llvm::cast(expr.getType()); - auto *dialect = - _builder.getContext()->getLoadedDialect(); - if (!dialect->isCastLegal(exprType.getSemiring(), ring)) { - return mlir::emitError(loc) - << "invalid cast from " << typeToString(exprType.getSemiring()) - << " to " << typeToString(ring); - } - - v = _builder.create( - loc, - _builder.getType(exprType.getRows(), exprType.getCols(), - ring), - expr); - return mlir::success(); + return parseAtomCast(v); } if (name == "zero") { - mlir::Type ring; - if (eatOrError(Token::LPAREN) || parseSemiring(ring) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - auto value = llvm::cast(ring).addIdentity(); - v = _builder.create(loc, value); - return mlir::success(); + return parseAtomZero(v); } if (name == "one") { - mlir::Type ring; - if (eatOrError(Token::LPAREN) || parseSemiring(ring) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - auto value = llvm::cast(ring).mulIdentity(); - v = _builder.create(loc, value); - return mlir::success(); + return parseAtomOne(v); } if (name == "apply") { - mlir::func::FuncOp func; - llvm::SmallVector args(1); - if (eatOrError(Token::LPAREN) || parseFuncRef(func) || - eatOrError(Token::COMMA) || parseExpr(args[0])) { - return mlir::failure(); - } - - if (cur().type == Token::COMMA) { - // Have a second arg. - auto &arg = args.emplace_back(); - if (eatOrError(Token::COMMA) || parseExpr(arg)) { - return mlir::failure(); - } - } - - if (eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - if (args.size() == 1) { - v = _builder.create(loc, func, args[0]); - } else { - assert(args.size() == 2); - v = _builder.create(loc, func, args[0], args[1]); - } - - return mlir::success(); + return parseAtomApply(v); } if (name == "select") { - mlir::func::FuncOp func; - llvm::SmallVector args(1); - if (eatOrError(Token::LPAREN) || parseFuncRef(func) || - eatOrError(Token::COMMA) || parseExpr(args[0])) { - return mlir::failure(); - } - - if (cur().type == Token::COMMA) { - // Have a second arg. - auto &arg = args.emplace_back(); - if (eatOrError(Token::COMMA) || parseExpr(arg)) { - return mlir::failure(); - } - } - - if (eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - if (args.size() == 1) { - v = _builder.create(loc, func.getSymName(), args[0]); - } else { - assert(args.size() == 2); - v = _builder.create(loc, func.getSymName(), args[0], - args[1]); - } - - return mlir::success(); + return parseAtomSelect(v); } if (name == "reduceRows") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - auto inputType = llvm::cast(arg.getType()); - auto *ctx = _builder.getContext(); - auto resultType = - MatrixType::get(ctx, inputType.getRows(), DimAttr::getOne(ctx), - inputType.getSemiring()); - v = _builder.create(loc, resultType, arg); - return mlir::success(); + return parseAtomReduceRows(v); } if (name == "reduceCols") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - auto inputType = llvm::cast(arg.getType()); - auto *ctx = _builder.getContext(); - auto resultType = - MatrixType::get(ctx, DimAttr::getOne(ctx), inputType.getCols(), - inputType.getSemiring()); - v = _builder.create(loc, resultType, arg); - return mlir::success(); + return parseAtomReduceCols(v); } if (name == "reduce") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - auto inputType = llvm::cast(arg.getType()); - v = _builder.create(loc, inputType.asScalar(), arg); - return mlir::success(); + return parseAtomReduce(v); } if (name == "pickAny") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - v = _builder.create(loc, arg); - return mlir::success(); + return parseAtomPickAny(v); } if (name == "diag") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - v = _builder.create(loc, arg); - return mlir::success(); + return parseAtomDiag(v); } // TODO: Make a separate extension if (name == "tril") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - v = _builder.create(loc, arg); - return mlir::success(); + return parseAtomTril(v); } // TODO: Make a separate extension if (name == "triu") { - mlir::Value arg; - if (eatOrError(Token::LPAREN) || parseExpr(arg) || - eatOrError(Token::RPAREN)) { - return mlir::failure(); - } - - v = _builder.create(loc, arg); - return mlir::success(); + return parseAtomTriu(v); } auto var = _symbolTable.lookup(name); @@ -1448,6 +1566,18 @@ mlir::ParseResult Parser::parseAtom(mlir::Value &v) { return mlir::failure(); } + // Check that NOT is only used with bool semiring + auto vType = llvm::cast(v.getType()); + auto semiring = vType.getSemiring(); + auto *ctx = _builder.getContext(); + if (semiring != SemiringTypes::forBool(ctx)) { + auto diag = mlir::emitError(loc) + << "not operator is only supported for bool type"; + diag.attachNote(v.getLoc()) + << "operand has semiring " << typeToString(semiring); + return mlir::failure(); + } + v = _builder.create(loc, v); return mlir::success(); } @@ -1456,6 +1586,19 @@ mlir::ParseResult Parser::parseAtom(mlir::Value &v) { return mlir::failure(); } + // Check that negation is only used with int or real semirings + auto vType = llvm::cast(v.getType()); + auto semiring = vType.getSemiring(); + auto *ctx = _builder.getContext(); + if (semiring != SemiringTypes::forInt(ctx) && + semiring != SemiringTypes::forReal(ctx)) { + auto diag = mlir::emitError(loc) + << "negation is only supported for int and real types"; + diag.attachNote(v.getLoc()) + << "operand has semiring " << typeToString(semiring); + return mlir::failure(); + } + v = _builder.create(loc, v); return mlir::success(); } @@ -1494,6 +1637,383 @@ static std::optional parseFloat(llvm::StringRef s) { return v; } +mlir::ParseResult Parser::parseAtomMatrix(mlir::Value &v) { + auto loc = cur().loc; + mlir::Type ring; + mlir::Value rowsExpr; + mlir::Value colsExpr; + if (eatOrError(Token::LANGLE) || parseSemiring(ring) || + eatOrError(Token::RANGLE) || eatOrError(Token::LPAREN) || + parseExpr(rowsExpr) || eatOrError(Token::COMMA) || parseExpr(colsExpr) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto rows = inferDim(rowsExpr, loc); + auto cols = inferDim(colsExpr, loc); + if (!rows || !cols) { + return mlir::failure(); + } + + v = _builder.create( + loc, _builder.getType(rows, cols, ring), + llvm::cast(ring).addIdentity()); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomVector(mlir::Value &v) { + auto loc = cur().loc; + mlir::Type ring; + mlir::Value rowsExpr; + if (eatOrError(Token::LANGLE) || parseSemiring(ring) || + eatOrError(Token::RANGLE) || eatOrError(Token::LPAREN) || + parseExpr(rowsExpr) || eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto rows = inferDim(rowsExpr, loc); + if (!rows) { + return mlir::failure(); + } + + auto *ctx = _builder.getContext(); + v = _builder.create( + loc, MatrixType::get(ctx, rows, DimAttr::getOne(ctx), ring), + llvm::cast(ring).addIdentity()); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomCast(mlir::Value &v) { + auto loc = cur().loc; + mlir::Type ring; + mlir::Value expr; + if (eatOrError(Token::LANGLE) || parseSemiring(ring) || + eatOrError(Token::RANGLE) || eatOrError(Token::LPAREN) || + parseExpr(expr) || eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto exprType = llvm::cast(expr.getType()); + auto *dialect = _builder.getContext()->getLoadedDialect(); + if (!dialect->isCastLegal(exprType.getSemiring(), ring)) { + return mlir::emitError(loc) + << "invalid cast from " << typeToString(exprType.getSemiring()) + << " to " << typeToString(ring); + } + + v = _builder.create(loc, + _builder.getType( + exprType.getRows(), exprType.getCols(), ring), + expr); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomZero(mlir::Value &v) { + auto loc = cur().loc; + mlir::Type ring; + if (eatOrError(Token::LPAREN) || parseSemiring(ring) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto value = llvm::cast(ring).addIdentity(); + v = _builder.create(loc, value); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomOne(mlir::Value &v) { + auto loc = cur().loc; + mlir::Type ring; + if (eatOrError(Token::LPAREN) || parseSemiring(ring) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto value = llvm::cast(ring).mulIdentity(); + v = _builder.create(loc, value); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomApply(mlir::Value &v) { + auto loc = cur().loc; + mlir::func::FuncOp func; + llvm::SmallVector args(1); + if (eatOrError(Token::LPAREN) || parseFuncRef(func) || + eatOrError(Token::COMMA) || parseExpr(args[0])) { + return mlir::failure(); + } + + if (cur().type == Token::COMMA) { + // Have a second arg. + auto &arg = args.emplace_back(); + if (eatOrError(Token::COMMA) || parseExpr(arg)) { + return mlir::failure(); + } + } + + if (eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + // Validate function signature + auto funcType = func.getFunctionType(); + size_t numFuncArgs = funcType.getNumInputs(); + + if (args.size() == 1) { + // Unary apply: function must have exactly 1 argument + if (numFuncArgs != 1) { + auto diag = mlir::emitError(loc) + << "apply() with 1 matrix argument requires a function " + "with 1 parameter, but got " + << numFuncArgs; + diag.attachNote(func.getLoc()) << "function defined here"; + return mlir::failure(); + } + } else { + // Binary apply: function must have exactly 2 arguments + if (numFuncArgs != 2) { + auto diag = mlir::emitError(loc) + << "apply() with 2 arguments requires a function with 2 " + "parameters, but got " + << numFuncArgs; + diag.attachNote(func.getLoc()) << "function defined here"; + return mlir::failure(); + } + + // Second argument must be a scalar + auto arg1Type = llvm::cast(args[1].getType()); + if (!arg1Type.isScalar()) { + auto diag = mlir::emitError(loc) + << "second argument to apply() must be a scalar"; + diag.attachNote(args[1].getLoc()) + << "argument has type " << typeToString(arg1Type); + return mlir::failure(); + } + } + + // Check that all function parameters are scalars + for (size_t i = 0; i < numFuncArgs; i++) { + auto paramType = llvm::dyn_cast(funcType.getInput(i)); + if (!paramType || !paramType.isScalar()) { + auto diag = mlir::emitError(loc) + << "apply() requires function parameters to be scalars"; + diag.attachNote(func.getLoc()) << "parameter " << i << " has type " + << typeToString(funcType.getInput(i)); + return mlir::failure(); + } + } + + if (args.size() == 1) { + v = _builder.create(loc, func, args[0]); + } else { + assert(args.size() == 2); + v = _builder.create(loc, func, args[0], args[1]); + } + + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomSelect(mlir::Value &v) { + auto loc = cur().loc; + mlir::func::FuncOp func; + llvm::SmallVector args(1); + if (eatOrError(Token::LPAREN) || parseFuncRef(func) || + eatOrError(Token::COMMA) || parseExpr(args[0])) { + return mlir::failure(); + } + + if (cur().type == Token::COMMA) { + // Have a second arg. + auto &arg = args.emplace_back(); + if (eatOrError(Token::COMMA) || parseExpr(arg)) { + return mlir::failure(); + } + } + + if (eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + // Validate function signature + auto funcType = func.getFunctionType(); + size_t numFuncArgs = funcType.getNumInputs(); + + if (args.size() == 1) { + // Unary select: function must have exactly 1 argument + if (numFuncArgs != 1) { + auto diag = mlir::emitError(loc) + << "select() with 1 matrix argument requires a function " + "with 1 parameter, but got " + << numFuncArgs; + diag.attachNote(func.getLoc()) << "function defined here"; + return mlir::failure(); + } + } else { + // Binary select: function must have exactly 2 arguments + if (numFuncArgs != 2) { + auto diag = mlir::emitError(loc) + << "select() with 2 arguments requires a function with 2 " + "parameters, but got " + << numFuncArgs; + diag.attachNote(func.getLoc()) << "function defined here"; + return mlir::failure(); + } + + // Second argument must be a scalar + auto arg1Type = llvm::cast(args[1].getType()); + if (!arg1Type.isScalar()) { + auto diag = mlir::emitError(loc) + << "second argument to select() must be a scalar"; + diag.attachNote(args[1].getLoc()) + << "argument has type " << typeToString(arg1Type); + return mlir::failure(); + } + } + + // Check that all function parameters are scalars + for (size_t i = 0; i < numFuncArgs; i++) { + auto paramType = llvm::dyn_cast(funcType.getInput(i)); + if (!paramType || !paramType.isScalar()) { + auto diag = mlir::emitError(loc) + << "select() requires function parameters to be scalars"; + diag.attachNote(func.getLoc()) << "parameter " << i << " has type " + << typeToString(funcType.getInput(i)); + return mlir::failure(); + } + } + + // Check that function returns a boolean scalar + if (funcType.getNumResults() != 1) { + auto diag = mlir::emitError(loc) + << "select() requires function to return exactly one value"; + diag.attachNote(func.getLoc()) << "function defined here"; + return mlir::failure(); + } + + auto returnType = llvm::dyn_cast(funcType.getResult(0)); + auto *ctx = _builder.getContext(); + auto expectedReturnType = MatrixType::scalarOf(SemiringTypes::forBool(ctx)); + if (!returnType || returnType != expectedReturnType) { + auto diag = mlir::emitError(loc) + << "select() requires function to return bool"; + diag.attachNote(func.getLoc()) + << "function returns " << typeToString(funcType.getResult(0)); + return mlir::failure(); + } + + if (args.size() == 1) { + v = _builder.create(loc, func.getSymName(), args[0]); + } else { + assert(args.size() == 2); + v = _builder.create(loc, func.getSymName(), args[0], + args[1]); + } + + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomReduceRows(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto inputType = llvm::cast(arg.getType()); + auto *ctx = _builder.getContext(); + auto resultType = MatrixType::get( + ctx, inputType.getRows(), DimAttr::getOne(ctx), inputType.getSemiring()); + v = _builder.create(loc, resultType, arg); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomReduceCols(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto inputType = llvm::cast(arg.getType()); + auto *ctx = _builder.getContext(); + auto resultType = MatrixType::get( + ctx, DimAttr::getOne(ctx), inputType.getCols(), inputType.getSemiring()); + v = _builder.create(loc, resultType, arg); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomReduce(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto inputType = llvm::cast(arg.getType()); + v = _builder.create(loc, inputType.asScalar(), arg); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomPickAny(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + v = _builder.create(loc, arg); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomDiag(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + auto argType = llvm::cast(arg.getType()); + if (!argType.isColumnVector() && !argType.isRowVector()) { + auto diag = mlir::emitError(loc) + << "diag() requires a row or column vector"; + diag.attachNote(arg.getLoc()) + << "argument has type " << typeToString(argType); + return mlir::failure(); + } + + v = _builder.create(loc, arg); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomTril(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + v = _builder.create(loc, arg); + return mlir::success(); +} + +mlir::ParseResult Parser::parseAtomTriu(mlir::Value &v) { + auto loc = cur().loc; + mlir::Value arg; + if (eatOrError(Token::LPAREN) || parseExpr(arg) || + eatOrError(Token::RPAREN)) { + return mlir::failure(); + } + + v = _builder.create(loc, arg); + return mlir::success(); +} + mlir::ParseResult Parser::parseLiteral(mlir::Type ring, mlir::Value &v) { auto *ctx = _builder.getContext(); mlir::TypedAttr attr; diff --git a/compiler/test/parse-err/accum-type-mismatch.gr b/compiler/test/parse-err/accum-type-mismatch.gr new file mode 100644 index 0000000..b4a5e66 --- /dev/null +++ b/compiler/test/parse-err/accum-type-mismatch.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func AccumTypeMismatch() -> int { + a = int(42); + // expected-error@below{{type of base does not match the expression to accumulate: (int vs. real}} + a += real(3.14); + return int(0); +} diff --git a/compiler/test/parse-err/accum-undefined.gr b/compiler/test/parse-err/accum-undefined.gr new file mode 100644 index 0000000..1e603f8 --- /dev/null +++ b/compiler/test/parse-err/accum-undefined.gr @@ -0,0 +1,7 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func AccumUndefined() -> int { + // expected-error@below{{undefined variable}} + a += int(42); + return int(0); +} diff --git a/compiler/test/parse-err/apply-binary-second-arg-not-scalar.gr b/compiler/test/parse-err/apply-binary-second-arg-not-scalar.gr new file mode 100644 index 0000000..fd6e998 --- /dev/null +++ b/compiler/test/parse-err/apply-binary-second-arg-not-scalar.gr @@ -0,0 +1,14 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func binary(a: int, b: int) -> int { + return a + b; +} + +func ApplyBinarySecondArgNotScalar( + a: Matrix, + // expected-note@below{{argument has type Vector}} + b: Vector) -> Matrix { + // Binary apply requires the second argument to be a scalar + // expected-error@below{{second argument to apply() must be a scalar}} + return apply(binary, a, b); +} diff --git a/compiler/test/parse-err/apply-func-non-scalar-args.gr b/compiler/test/parse-err/apply-func-non-scalar-args.gr new file mode 100644 index 0000000..1d1c1f6 --- /dev/null +++ b/compiler/test/parse-err/apply-func-non-scalar-args.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{parameter 0 has type Matrix}} +func nonScalarFunc(m: Matrix) -> int { + return int(0); +} + +func ApplyFuncNonScalarArgs(m: Matrix) -> Matrix { + // apply() requires the function to have scalar arguments + // expected-error@below{{apply() requires function parameters to be scalars}} + return apply(nonScalarFunc, m); +} diff --git a/compiler/test/parse-err/apply-unary-func-wrong-arg-count.gr b/compiler/test/parse-err/apply-unary-func-wrong-arg-count.gr new file mode 100644 index 0000000..cdb32cc --- /dev/null +++ b/compiler/test/parse-err/apply-unary-func-wrong-arg-count.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{function defined here}} +func twoArgs(a: int, b: int) -> int { + return a + b; +} + +func ApplyUnaryFuncWrongArgCount(m: Matrix) -> Matrix { + // Unary apply requires the function to have exactly 1 argument + // expected-error@below{{apply() with 1 matrix argument requires a function with 1 parameter, but got 2}} + return apply(twoArgs, m); +} diff --git a/compiler/test/parse-err/apply-undefined-func.gr b/compiler/test/parse-err/apply-undefined-func.gr new file mode 100644 index 0000000..697f55f --- /dev/null +++ b/compiler/test/parse-err/apply-undefined-func.gr @@ -0,0 +1,7 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func ApplyUndefinedFunc(m: Matrix) -> Matrix { + // The function 'undefinedFunc' does not exist + // expected-error@below{{unknown function 'undefinedFunc'}} + return apply(undefinedFunc, m); +} diff --git a/compiler/test/parse-err/cmp-bool-unsupported.gr b/compiler/test/parse-err/cmp-bool-unsupported.gr new file mode 100644 index 0000000..00a2cf2 --- /dev/null +++ b/compiler/test/parse-err/cmp-bool-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func CmpBoolUnsupported( + // expected-note@below{{operands have semiring bool}} + a: bool, b: bool) -> bool { + // Comparison is not supported for bool semiring (only int, real, trop_real) + // expected-error@below{{ordered compare is only supported for int and real semirings}} + return a < b; +} diff --git a/compiler/test/parse-err/cmp-trop-int-unsupported.gr b/compiler/test/parse-err/cmp-trop-int-unsupported.gr new file mode 100644 index 0000000..9bbc9ff --- /dev/null +++ b/compiler/test/parse-err/cmp-trop-int-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func CmpTropIntUnsupported( + // expected-note@below{{operands have semiring trop_int}} + a: trop_int, b: trop_int) -> bool { + // Comparison is not supported for trop_int semiring (only int, real, trop_real) + // expected-error@below{{ordered compare is only supported for int and real semirings}} + return a < b; +} diff --git a/compiler/test/parse-err/diag-not-vector.gr b/compiler/test/parse-err/diag-not-vector.gr new file mode 100644 index 0000000..81297cc --- /dev/null +++ b/compiler/test/parse-err/diag-not-vector.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func DiagNotVector( + // expected-note@below{{argument has type Matrix}} + m: Matrix) -> Matrix { + // diag() requires a row or column vector, but m is a general matrix + // expected-error@below{{diag() requires a row or column vector}} + return diag(m); +} diff --git a/compiler/test/parse-err/div-bool-unsupported.gr b/compiler/test/parse-err/div-bool-unsupported.gr new file mode 100644 index 0000000..5e60aa6 --- /dev/null +++ b/compiler/test/parse-err/div-bool-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func DivBoolUnsupported( + // expected-note@below{{operands have semiring bool}} + a: bool, b: bool) -> bool { + // Division is not supported for bool semiring (only real and trop_int) + // expected-error@below{{division is only supported for real semiring}} + return a / b; +} diff --git a/compiler/test/parse-err/div-int-unsupported.gr b/compiler/test/parse-err/div-int-unsupported.gr new file mode 100644 index 0000000..277c8a4 --- /dev/null +++ b/compiler/test/parse-err/div-int-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func DivIntUnsupported( + // expected-note@below{{operands have semiring int}} + a: int, b: int) -> int { + // Division is not supported for int semiring (only real and trop_int) + // expected-error@below{{division is only supported for real semiring}} + return a / b; +} diff --git a/compiler/test/parse-err/div-trop-real-unsupported.gr b/compiler/test/parse-err/div-trop-real-unsupported.gr new file mode 100644 index 0000000..4673f1c --- /dev/null +++ b/compiler/test/parse-err/div-trop-real-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func DivTropRealUnsupported( + // expected-note@below{{operands have semiring trop_real}} + a: trop_real, b: trop_real) -> trop_real { + // Division is not supported for trop_real semiring (only real and trop_int) + // expected-error@below{{division is only supported for real semiring}} + return a / b; +} diff --git a/compiler/test/parse-err/ewise-different-dimensions.gr b/compiler/test/parse-err/ewise-different-dimensions.gr new file mode 100644 index 0000000..061cb21 --- /dev/null +++ b/compiler/test/parse-err/ewise-different-dimensions.gr @@ -0,0 +1,11 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func EwiseDifferentDimensions( + // expected-note@below{{left operand has dimensions (s x s)}} + a: Matrix, + // expected-note@below{{right operand has dimensions (t x t)}} + b: Matrix) -> Matrix { + // Element-wise operations require operands to have the same dimensions + // expected-error@below{{operands have different dimensions}} + return a (.+) b; +} diff --git a/compiler/test/parse-err/ewise-different-semirings.gr b/compiler/test/parse-err/ewise-different-semirings.gr new file mode 100644 index 0000000..3e4a31c --- /dev/null +++ b/compiler/test/parse-err/ewise-different-semirings.gr @@ -0,0 +1,11 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func EwiseDifferentSemirings( + // expected-note@below{{left operand has semiring int}} + a: Matrix, + // expected-note@below{{right operand has semiring real}} + b: Matrix) -> Matrix { + // Element-wise operations require operands to have the same semiring + // expected-error@below{{operands have different semirings}} + return a (.+) b; +} diff --git a/compiler/test/parse-err/ewise-func-different-dimensions.gr b/compiler/test/parse-err/ewise-func-different-dimensions.gr new file mode 100644 index 0000000..fa6c2ec --- /dev/null +++ b/compiler/test/parse-err/ewise-func-different-dimensions.gr @@ -0,0 +1,14 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func addInts(a: int, b: int) -> int { + return a + b; +} + +func EwiseFuncDifferentDimensions( + // expected-note@below{{left operand has dimensions (s x s)}} + m: Matrix, + // expected-note@below{{right operand has dimensions (t x t)}} + n: Matrix) -> Matrix { + // expected-error@below{{operands have different dimensions}} + return m (.addInts) n; +} diff --git a/compiler/test/parse-err/ewise-func-non-scalar-params.gr b/compiler/test/parse-err/ewise-func-non-scalar-params.gr new file mode 100644 index 0000000..7e54b38 --- /dev/null +++ b/compiler/test/parse-err/ewise-func-non-scalar-params.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{first parameter has type Matrix}} +// expected-note@below{{second parameter has type int}} +func matrixParam(m: Matrix, n: int) -> int { + return n; +} + +func EwiseFuncNonScalarParams(m: Matrix, n: Matrix) -> Matrix { + // expected-error@below{{element-wise function application requires function parameters to be scalars}} + return m (.matrixParam) n; +} diff --git a/compiler/test/parse-err/ewise-func-type-mismatch-left.gr b/compiler/test/parse-err/ewise-func-type-mismatch-left.gr new file mode 100644 index 0000000..f8d36c4 --- /dev/null +++ b/compiler/test/parse-err/ewise-func-type-mismatch-left.gr @@ -0,0 +1,13 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{first parameter has semiring real}} +func addReals(a: real, b: real) -> real { + return a + b; +} + +func EwiseFuncTypeMismatchLeft( + // expected-note@below{{left operand has semiring int}} + m: Matrix, n: Matrix) -> Matrix { + // expected-error@below{{left operand semiring does not match first parameter type}} + return m (.addReals) n; +} diff --git a/compiler/test/parse-err/ewise-func-type-mismatch-right.gr b/compiler/test/parse-err/ewise-func-type-mismatch-right.gr new file mode 100644 index 0000000..8ff3258 --- /dev/null +++ b/compiler/test/parse-err/ewise-func-type-mismatch-right.gr @@ -0,0 +1,14 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{second parameter has semiring real}} +func addMixed(a: int, b: real) -> real { + return cast(a) + b; +} + +func EwiseFuncTypeMismatchRight( + m: Matrix, + // expected-note@below{{right operand has semiring int}} + n: Matrix) -> Matrix { + // expected-error@below{{right operand semiring does not match second parameter type}} + return m (.addMixed) n; +} diff --git a/compiler/test/parse-err/ewise-func-undefined.gr b/compiler/test/parse-err/ewise-func-undefined.gr new file mode 100644 index 0000000..9cbaa34 --- /dev/null +++ b/compiler/test/parse-err/ewise-func-undefined.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func EwiseFuncUndefined(m: Matrix, n: Matrix) -> Matrix { + // expected-error@below{{unknown function 'doesNotExist'}} + return m (.doesNotExist) n; +} diff --git a/compiler/test/parse-err/ewise-func-wrong-param-count.gr b/compiler/test/parse-err/ewise-func-wrong-param-count.gr new file mode 100644 index 0000000..7abe4dc --- /dev/null +++ b/compiler/test/parse-err/ewise-func-wrong-param-count.gr @@ -0,0 +1,11 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{function defined here}} +func oneParam(a: int) -> int { + return a; +} + +func EwiseFuncWrongParamCount(m: Matrix, n: Matrix) -> Matrix { + // expected-error@below{{element-wise function application requires a function with 2 parameters, but got 1}} + return m (.oneParam) n; +} diff --git a/compiler/test/parse-err/func-name-dup.gr b/compiler/test/parse-err/func-name-dup.gr new file mode 100644 index 0000000..ad6e1e8 --- /dev/null +++ b/compiler/test/parse-err/func-name-dup.gr @@ -0,0 +1,11 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{original definition here}} +func Dup(a: int) -> int { + return a; +} + +// expected-error@below{{duplicate definition of function 'Dup'}} +func Dup(a: int) -> int { + return a; +} diff --git a/compiler/test/parse-err/func-param-dup.gr b/compiler/test/parse-err/func-param-dup.gr new file mode 100644 index 0000000..ca6bf1c --- /dev/null +++ b/compiler/test/parse-err/func-param-dup.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func Dup( + // expected-note@below{{previous definition here}} + a: int, + // expected-error@below{{duplicate parameter name 'a'}} + a: int) -> int { + return a; +} diff --git a/compiler/test/parse-err/literal-bool-from-int.gr b/compiler/test/parse-err/literal-bool-from-int.gr new file mode 100644 index 0000000..7bbc15b --- /dev/null +++ b/compiler/test/parse-err/literal-bool-from-int.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LiteralBoolFromInt() -> bool { + // expected-error@below{{expected 'true' or 'false'}} + return bool(42); +} diff --git a/compiler/test/parse-err/literal-int-from-real.gr b/compiler/test/parse-err/literal-int-from-real.gr new file mode 100644 index 0000000..1d50003 --- /dev/null +++ b/compiler/test/parse-err/literal-int-from-real.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LiteralIntFromReal() -> int { + // expected-error@below{{expected an integer value}} + return int(42.0); +} diff --git a/compiler/test/parse-err/literal-trop-real-from-bool.gr b/compiler/test/parse-err/literal-trop-real-from-bool.gr new file mode 100644 index 0000000..7b9239a --- /dev/null +++ b/compiler/test/parse-err/literal-trop-real-from-bool.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LiteralTropRealFromBool() -> trop_real { + // expected-error@below{{expected a floating-point value}} + return trop_real(false); +} diff --git a/compiler/test/parse-err/loop-range-end-non-int.gr b/compiler/test/parse-err/loop-range-end-non-int.gr new file mode 100644 index 0000000..60b0aa5 --- /dev/null +++ b/compiler/test/parse-err/loop-range-end-non-int.gr @@ -0,0 +1,10 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LoopRangeEndNonInt() -> int { + a = int(0); + // expected-error@below{{loop range end must be an integer, but got real}} + for i in int(1):real(10.0) { + a = a + int(1); + } + return a; +} diff --git a/compiler/test/parse-err/loop-range-not-dimension.gr b/compiler/test/parse-err/loop-range-not-dimension.gr new file mode 100644 index 0000000..1a3d52d --- /dev/null +++ b/compiler/test/parse-err/loop-range-not-dimension.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LoopRangeNotDimension(m: Matrix) -> int { + a = int(0); + // m.nvals is an integer, not a dimension + // expected-error@below{{not a dimension type}} + // expected-note@below{{defined here}} + for i in m.nvals { + a = a + int(1); + } + return a; +} diff --git a/compiler/test/parse-err/loop-range-start-non-int.gr b/compiler/test/parse-err/loop-range-start-non-int.gr new file mode 100644 index 0000000..be83ae3 --- /dev/null +++ b/compiler/test/parse-err/loop-range-start-non-int.gr @@ -0,0 +1,10 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LoopRangeStartNonInt() -> int { + a = int(0); + // expected-error@below{{loop range start must be an integer, but got real}} + for i in real(1.0):int(10) { + a = a + int(1); + } + return a; +} diff --git a/compiler/test/parse-err/loop-scope.gr b/compiler/test/parse-err/loop-scope.gr new file mode 100644 index 0000000..4635679 --- /dev/null +++ b/compiler/test/parse-err/loop-scope.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LoopScope() -> int { + a = int(0); + for i in int(1):int(10) { + b = int(42); + a = a + b; + } + // Variable b should not be accessible outside the loop + // expected-error@below{{unrecognized variable}} + return b; +} diff --git a/compiler/test/parse-err/loop-until-non-bool.gr b/compiler/test/parse-err/loop-until-non-bool.gr new file mode 100644 index 0000000..84f01d2 --- /dev/null +++ b/compiler/test/parse-err/loop-until-non-bool.gr @@ -0,0 +1,10 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func LoopUntilNonBool() -> int { + a = int(0); + for i in int(1):int(10) { + a = a + int(1); + // expected-error@below{{loop condition does not produce a boolean scalar, got int}} + } until int(5); + return a; +} diff --git a/compiler/test/parse-err/mask-dimension-mismatch.gr b/compiler/test/parse-err/mask-dimension-mismatch.gr new file mode 100644 index 0000000..6a91b46 --- /dev/null +++ b/compiler/test/parse-err/mask-dimension-mismatch.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func MaskDimensionMismatch( + a: Matrix, + m: Matrix, + e: Matrix) -> Matrix { + // expected-error@below{{base dimensions do not match the dimensions of the mask}} + // expected-note@below{{base dimension: (s x s)}} + // expected-note@below{{mask dimensions: (t x t)}} + a = e; + return a; +} diff --git a/compiler/test/parse-err/matmul-dimension-mismatch.gr b/compiler/test/parse-err/matmul-dimension-mismatch.gr new file mode 100644 index 0000000..cc6a5c8 --- /dev/null +++ b/compiler/test/parse-err/matmul-dimension-mismatch.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func MatMulDimensionMismatch( + // expected-note@below{{left side has dimensions (r x s)}} + a: Matrix, + // expected-note@below{{right side has dimensions (t x u)}} + b: Matrix) -> Matrix { + // Left matrix has s columns, right matrix has t rows + // These dimensions must match for matrix multiplication + // expected-error@below{{incompatible dimensions for matrix multiply}} + return a * b; +} diff --git a/compiler/test/parse-err/matrix-fill-col-vector.gr b/compiler/test/parse-err/matrix-fill-col-vector.gr new file mode 100644 index 0000000..f5d6df8 --- /dev/null +++ b/compiler/test/parse-err/matrix-fill-col-vector.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func MatrixFillColVector(v: Vector, x: int) -> Vector { + // expected-error@below{{matrix fill [:, :] used with column vector base}} + // expected-note@below{{base has type Vector}} + v[:, :] = x; + return v; +} diff --git a/compiler/test/parse-err/neg-bool-unsupported.gr b/compiler/test/parse-err/neg-bool-unsupported.gr new file mode 100644 index 0000000..7ec17e5 --- /dev/null +++ b/compiler/test/parse-err/neg-bool-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func NegBoolUnsupported( + // expected-note@below{{operand has semiring bool}} + a: bool) -> bool { + // Negation is not supported for bool semiring (only int and real) + // expected-error@below{{negation is only supported for int and real types}} + return -a; +} diff --git a/compiler/test/parse-err/neg-trop-int-unsupported.gr b/compiler/test/parse-err/neg-trop-int-unsupported.gr new file mode 100644 index 0000000..a841c6b --- /dev/null +++ b/compiler/test/parse-err/neg-trop-int-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func NegTropIntUnsupported( + // expected-note@below{{operand has semiring trop_int}} + a: trop_int) -> trop_int { + // Negation is not supported for tropical semirings (only int and real) + // expected-error@below{{negation is only supported for int and real types}} + return -a; +} diff --git a/compiler/test/parse-err/not-int-unsupported.gr b/compiler/test/parse-err/not-int-unsupported.gr new file mode 100644 index 0000000..740ee47 --- /dev/null +++ b/compiler/test/parse-err/not-int-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func NotIntUnsupported( + // expected-note@below{{operand has semiring int}} + a: int) -> int { + // NOT is not supported for int semiring (only bool) + // expected-error@below{{not operator is only supported for bool type}} + return !a; +} diff --git a/compiler/test/parse-err/not-real-unsupported.gr b/compiler/test/parse-err/not-real-unsupported.gr new file mode 100644 index 0000000..c8256e2 --- /dev/null +++ b/compiler/test/parse-err/not-real-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func NotRealUnsupported( + // expected-note@below{{operand has semiring real}} + a: real) -> real { + // NOT is not supported for real semiring (only bool) + // expected-error@below{{not operator is only supported for bool type}} + return !a; +} diff --git a/compiler/test/parse-err/not-trop-int-unsupported.gr b/compiler/test/parse-err/not-trop-int-unsupported.gr new file mode 100644 index 0000000..30fe961 --- /dev/null +++ b/compiler/test/parse-err/not-trop-int-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func NotTropIntUnsupported( + // expected-note@below{{operand has semiring trop_int}} + a: trop_int) -> trop_int { + // NOT is not supported for trop_int semiring (only bool) + // expected-error@below{{not operator is only supported for bool type}} + return !a; +} diff --git a/compiler/test/parse-err/reassign-type-mismatch.gr b/compiler/test/parse-err/reassign-type-mismatch.gr new file mode 100644 index 0000000..65dfc87 --- /dev/null +++ b/compiler/test/parse-err/reassign-type-mismatch.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func ReassignTypeMismatch() -> int { + // expected-note@below{{previous assigment was here}} + a = int(42); + // expected-error@below{{cannot assign value of type real to previously defined variable of type int}} + a = real(3.14); + return int(0); +} diff --git a/compiler/test/parse-err/return-in-loop.gr b/compiler/test/parse-err/return-in-loop.gr new file mode 100644 index 0000000..ddfb472 --- /dev/null +++ b/compiler/test/parse-err/return-in-loop.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func ReturnInLoop() -> int { + for i in int(1):int(10) { + // expected-error@below{{return statement inside a loop is not allowed}} + return int(5); + } + return int(0); +} diff --git a/compiler/test/parse-err/return-missing.gr b/compiler/test/parse-err/return-missing.gr new file mode 100644 index 0000000..d153997 --- /dev/null +++ b/compiler/test/parse-err/return-missing.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-error@below{{function must have a return statement}} +func ReturnMissing() -> int { + a = int(42); +} diff --git a/compiler/test/parse-err/return-multiple.gr b/compiler/test/parse-err/return-multiple.gr new file mode 100644 index 0000000..68d7074 --- /dev/null +++ b/compiler/test/parse-err/return-multiple.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func ReturnMultiple() -> int { + a = int(42); + return a; + // expected-error@below{{statement after return is not allowed}} + return int(5); +} diff --git a/compiler/test/parse-err/return-not-last.gr b/compiler/test/parse-err/return-not-last.gr new file mode 100644 index 0000000..22109dc --- /dev/null +++ b/compiler/test/parse-err/return-not-last.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func ReturnNotLast() -> int { + return int(42); + // expected-error@below{{statement after return is not allowed}} + a = int(5); + return a; +} diff --git a/compiler/test/parse-err/return-wrong-type.gr b/compiler/test/parse-err/return-wrong-type.gr new file mode 100644 index 0000000..7638cb9 --- /dev/null +++ b/compiler/test/parse-err/return-wrong-type.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func ReturnWrongType() -> int { + // expected-error@below{{return type mismatch: expected int, but got real}} + return real(3.14); +} diff --git a/compiler/test/parse-err/select-binary-func-non-bool-return.gr b/compiler/test/parse-err/select-binary-func-non-bool-return.gr new file mode 100644 index 0000000..36b929f --- /dev/null +++ b/compiler/test/parse-err/select-binary-func-non-bool-return.gr @@ -0,0 +1,14 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{function returns real}} +func returnsBinaryReal(a: int, b: int) -> real { + return real(3.14); +} + +func SelectBinaryFuncNonBoolReturn( + a: Matrix, + b: int) -> Matrix { + // Binary select also requires the function to return bool + // expected-error@below{{select() requires function to return bool}} + return select(returnsBinaryReal, a, b); +} diff --git a/compiler/test/parse-err/select-binary-second-arg-not-scalar.gr b/compiler/test/parse-err/select-binary-second-arg-not-scalar.gr new file mode 100644 index 0000000..57fb340 --- /dev/null +++ b/compiler/test/parse-err/select-binary-second-arg-not-scalar.gr @@ -0,0 +1,14 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func binary(a: int, b: int) -> bool { + return a == b; +} + +func SelectBinarySecondArgNotScalar( + a: Matrix, + // expected-note@below{{argument has type Vector}} + b: Vector) -> Matrix { + // Binary select requires the second argument to be a scalar + // expected-error@below{{second argument to select() must be a scalar}} + return select(binary, a, b); +} diff --git a/compiler/test/parse-err/select-func-non-bool-return.gr b/compiler/test/parse-err/select-func-non-bool-return.gr new file mode 100644 index 0000000..2d98147 --- /dev/null +++ b/compiler/test/parse-err/select-func-non-bool-return.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{function returns int}} +func returnsInt(a: int) -> int { + return a; +} + +func SelectFuncNonBoolReturn(m: Matrix) -> Matrix { + // select() requires the function to return bool + // expected-error@below{{select() requires function to return bool}} + return select(returnsInt, m); +} diff --git a/compiler/test/parse-err/select-func-non-scalar-args.gr b/compiler/test/parse-err/select-func-non-scalar-args.gr new file mode 100644 index 0000000..b18114e --- /dev/null +++ b/compiler/test/parse-err/select-func-non-scalar-args.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{parameter 0 has type Matrix}} +func nonScalarFunc(m: Matrix) -> bool { + return bool(true); +} + +func SelectFuncNonScalarArgs(m: Matrix) -> Matrix { + // select() requires the function to have scalar arguments + // expected-error@below{{select() requires function parameters to be scalars}} + return select(nonScalarFunc, m); +} diff --git a/compiler/test/parse-err/select-unary-func-wrong-arg-count.gr b/compiler/test/parse-err/select-unary-func-wrong-arg-count.gr new file mode 100644 index 0000000..0eb9905 --- /dev/null +++ b/compiler/test/parse-err/select-unary-func-wrong-arg-count.gr @@ -0,0 +1,12 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +// expected-note@below{{function defined here}} +func twoArgs(a: int, b: int) -> bool { + return a == b; +} + +func SelectUnaryFuncWrongArgCount(m: Matrix) -> Matrix { + // Unary select requires the function to have exactly 1 argument + // expected-error@below{{select() with 1 matrix argument requires a function with 1 parameter, but got 2}} + return select(twoArgs, m); +} diff --git a/compiler/test/parse-err/select-undefined-func.gr b/compiler/test/parse-err/select-undefined-func.gr new file mode 100644 index 0000000..9746575 --- /dev/null +++ b/compiler/test/parse-err/select-undefined-func.gr @@ -0,0 +1,7 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func SelectUndefinedFunc(m: Matrix) -> Matrix { + // The function 'undefinedFunc' does not exist + // expected-error@below{{unknown function 'undefinedFunc'}} + return select(undefinedFunc, m); +} diff --git a/compiler/test/parse-err/sub-bool-unsupported.gr b/compiler/test/parse-err/sub-bool-unsupported.gr new file mode 100644 index 0000000..109e520 --- /dev/null +++ b/compiler/test/parse-err/sub-bool-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func SubBoolUnsupported( + // expected-note@below{{operands have semiring bool}} + a: bool, b: bool) -> bool { + // Subtraction is not supported for bool semiring (only int and real) + // expected-error@below{{subtraction is only supported for int and real semirings}} + return a - b; +} diff --git a/compiler/test/parse-err/sub-matrix-not-scalar.gr b/compiler/test/parse-err/sub-matrix-not-scalar.gr new file mode 100644 index 0000000..b4415bb --- /dev/null +++ b/compiler/test/parse-err/sub-matrix-not-scalar.gr @@ -0,0 +1,11 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func SubMatrixNotScalar( + // expected-note@below{{left operand has dimensions (s x s)}} + a: Matrix, + // expected-note@below{{right operand has dimensions (s x s)}} + b: Matrix) -> Matrix { + // Subtraction only works on scalars, not matrices + // expected-error@below{{operands are not scalar. Did you mean to use element-wise (a (.f) b) syntax?}} + return a - b; +} diff --git a/compiler/test/parse-err/sub-trop-int-unsupported.gr b/compiler/test/parse-err/sub-trop-int-unsupported.gr new file mode 100644 index 0000000..192a9e1 --- /dev/null +++ b/compiler/test/parse-err/sub-trop-int-unsupported.gr @@ -0,0 +1,9 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func SubTropIntUnsupported( + // expected-note@below{{operands have semiring trop_int}} + a: trop_int, b: trop_int) -> trop_int { + // Subtraction is not supported for tropical semirings (only int and real) + // expected-error@below{{subtraction is only supported for int and real semirings}} + return a - b; +} diff --git a/compiler/test/parse-err/undefined-var.gr b/compiler/test/parse-err/undefined-var.gr new file mode 100644 index 0000000..b7931e8 --- /dev/null +++ b/compiler/test/parse-err/undefined-var.gr @@ -0,0 +1,6 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func UndefinedVar() -> int { + // expected-error@below{{unrecognized variable}} + return i_am_undefined; +} diff --git a/compiler/test/parse-err/vector-fill-matrix.gr b/compiler/test/parse-err/vector-fill-matrix.gr new file mode 100644 index 0000000..8f75de0 --- /dev/null +++ b/compiler/test/parse-err/vector-fill-matrix.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func VectorFillMatrix(m: Matrix, x: int) -> Matrix { + // expected-error@below{{vector fill [:] used with non-vector base}} + // expected-note@below{{base has type Matrix}} + m[:] = x; + return m; +} diff --git a/compiler/test/parse-err/vector-fill-non-scalar.gr b/compiler/test/parse-err/vector-fill-non-scalar.gr new file mode 100644 index 0000000..dad881a --- /dev/null +++ b/compiler/test/parse-err/vector-fill-non-scalar.gr @@ -0,0 +1,7 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func VectorFillNonScalar(v: Vector, e: Vector) -> Vector { + // expected-error@below{{fill expression is not a scalar}} + v[:] = e; + return v; +} diff --git a/compiler/test/parse-err/vector-fill-row-vector.gr b/compiler/test/parse-err/vector-fill-row-vector.gr new file mode 100644 index 0000000..d2f365b --- /dev/null +++ b/compiler/test/parse-err/vector-fill-row-vector.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func VectorFillRowVector(m: Matrix<1, s, int>, x: int) -> Matrix<1, s, int> { + // expected-error@below{{vector fill [:] used with non-vector base}} + // expected-note@below{{base has type Matrix<1, s, int>}} + m[:] = x; + return m; +} diff --git a/compiler/test/parse/arith.gr b/compiler/test/parse/arith.gr index e22fb30..cd5e8cd 100644 --- a/compiler/test/parse/arith.gr +++ b/compiler/test/parse/arith.gr @@ -22,7 +22,7 @@ func Mul(a:int, b:int, c:int) -> int { } // CHECK-LABEL: @Div -func Div(a:int, b:int, c:int) -> int { +func Div(a:real, b:real, c:real) -> real { // CHECK: %[[#DIV:]] = graphalg.ewise %arg0 DIV %arg1 // CHECK: return %[[#DIV]] return a / b; diff --git a/compiler/test/parse/diag-scalar.gr b/compiler/test/parse/diag-scalar.gr new file mode 100644 index 0000000..14e7328 --- /dev/null +++ b/compiler/test/parse/diag-scalar.gr @@ -0,0 +1,8 @@ +// RUN: graphalg-translate --import-graphalg < %s | FileCheck %s + +// Scalars (1x1 matrices) are both row and column vectors, so diag should accept them +// CHECK-LABEL: @DiagScalar +func DiagScalar(x: int) -> int { + // CHECK: graphalg.diag %arg0 + return diag(x); +} diff --git a/docs/parse-tests.md b/docs/parse-tests.md new file mode 100644 index 0000000..3eaeaab --- /dev/null +++ b/docs/parse-tests.md @@ -0,0 +1,191 @@ +# Parser Testing Guide + +This guide explains how to write and test parser error tests for the GraphAlg compiler. + +## Test Framework + +The GraphAlg compiler uses **LLVM's `lit` (LLVM Integrated Tester)** framework for testing. Tests are located in `compiler/test/`. + +### Test Types + +1. **Success tests** (`compiler/test/parse/`) - Verify correct parsing and MLIR output using FileCheck +2. **Error tests** (`compiler/test/parse-err/`) - Verify that invalid code produces expected error messages + +## Writing Parser Error Tests + +### Basic Structure + +```graphalg +// RUN: graphalg-translate --import-graphalg --verify-diagnostics %s + +func MyTest() -> int { + // expected-error@below{{error message}} + invalid code here; + return int(0); +} +``` + +### Key Components + +1. **RUN directive**: Specifies how to run the test + - `--verify-diagnostics` flag enables diagnostic verification + - `%s` is replaced with the test file path + +2. **Diagnostic annotations**: Tell the verifier what errors/notes to expect + - `// expected-error@below{...}` - expects an error on the next line + - `// expected-note@below{...}` - expects a note/additional info on the next line + - Messages in braces must match the actual diagnostic message exactly + +3. **Error location**: The verifier checks that errors appear at the expected location + +### Multiple Diagnostics + +When multiple diagnostics appear on the same statement, list them all: + +```graphalg +// expected-error@below{{base dimensions do not match the dimensions of the mask}} +// expected-note@below{{base dimension: (s x s)}} +// expected-note@below{{mask dimensions: (t x t)}} +a = e; +``` + +## Running Tests + +### Run All Tests + +```bash +cmake --build compiler/build --target check +``` + +### Run Individual Tests + +```bash +# First build the compiler +cmake --build compiler/build --target graphalg-translate + +# Run with diagnostic verification +./compiler/build/tools/graphalg-translate --import-graphalg --verify-diagnostics path/to/test.gr + +# Run without verification to see actual errors +./compiler/build/tools/graphalg-translate --import-graphalg path/to/test.gr +``` + +## Common Error Categories + +### Duplicate Definitions + +**Function names** (`func-name-dup.gr`): +```graphalg +// expected-note@below{{original definition here}} +func Dup(a: int) -> int { return a; } + +// expected-error@below{{duplicate definition of function 'Dup'}} +func Dup(a: int) -> int { return a; } +``` + +**Parameter names** (`func-param-dup.gr`): +```graphalg +func Dup( + // expected-note@below{{previous definition here}} + a: int, + // expected-error@below{{duplicate parameter name 'a'}} + a: int) -> int { return a; } +``` + +### Type Errors + +**Reassignment with different type** (`reassign-type-mismatch.gr`): +```graphalg +func Test() -> int { + // expected-note@below{{previous assigment was here}} + a = int(42); + // expected-error@below{{cannot assign value of type real to previously defined variable of type int}} + a = real(3.14); + return int(0); +} +``` + +**Accumulate type mismatch** (`accum-type-mismatch.gr`): +```graphalg +func Test() -> int { + a = int(42); + // expected-error@below{{type of base does not match the expression to accumulate: (int vs. real}} + a += real(3.14); + return int(0); +} +``` + +### Variable Scoping + +**Undefined variable** (`accum-undefined.gr`): +```graphalg +func Test() -> int { + // expected-error@below{{undefined variable}} + a += int(42); + return int(0); +} +``` + +**Loop scope** (`loop-scope.gr`): +```graphalg +func Test() -> int { + a = int(0); + for i in int(1):int(10) { + b = int(42); + a = a + b; + } + // Variable b is not accessible outside the loop + // expected-error@below{{unrecognized variable}} + return b; +} +``` + +## Type Formatting in Error Messages + +The parser formats types in a user-friendly way: +- Scalars: `int`, `real`, `bool`, `trop_int`, `trop_real` +- Vectors: `Vector` (column vector with dimension `s` and element type `int`) +- Matrices: `Matrix` (matrix with row dimension `r`, column dimension `c`, and element type `int`) + +## Tips for Writing Tests + +1. **Test one error at a time** - Keep tests focused on a single error condition +2. **Use descriptive function names** - Name the function after what it tests +3. **Add comments** - Explain what the test is checking +4. **Verify exact error messages** - The diagnostic message must match exactly +5. **Check location precision** - Ensure the error points to the right token +6. **Test boundary cases** - Cover edge cases like row vectors, column vectors, and full matrices + +## Adding Parser Error Checks + +When adding new error detection to the parser: + +1. **Identify the error condition** in the parser code +2. **Choose an appropriate error message** - Be clear and user-friendly +3. **Use `mlir::emitError(location)`** to report the error +4. **Add notes with `diag.attachNote()`** for additional context +5. **Write a test** that verifies the error is caught +6. **Run the test** to verify the exact error message format +7. **Update the test** with the correct expected message + +## Best Practices for Error Messages + +### Always Use TypeFormatter for Type Display + +When emitting error messages that include types, **always use `typeToString()`** instead of directly printing MLIR types. This ensures user-friendly error messages. + +**Correct:** +```cpp +auto diag = mlir::emitError(loc) + << "parameter has type " << typeToString(funcType.getInput(i)); +``` + +**Incorrect:** +```cpp +auto diag = mlir::emitError(loc) + << "parameter has type " << funcType.getInput(i); // Shows raw MLIR type +``` + +The `typeToString()` function uses `TypeFormatter` internally, which formats types in a user-friendly way: +- Raw MLIR: `!graphalg.mat x distinct[0]<> x i64>` +- Formatted: `Matrix`