From 4211a82f146993a429b86b3ba081644af19b853c Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Fri, 22 Mar 2024 11:48:26 +0100 Subject: [PATCH 1/9] Fat JAR that generates C code --- .gitignore | 2 + README.md | 17 ++++++++ build.sbt | 29 +++++++++++++- float-safe-optimizer/Main.scala | 40 +++++++++++++++++++ float-safe-optimizer/examples/.gitignore | 1 + float-safe-optimizer/examples/add3.rise | 3 ++ float-safe-optimizer/examples/add3Seq.rise | 3 ++ .../examples/add3TypeError.rise | 3 ++ project/plugins.sbt | 2 + 9 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 float-safe-optimizer/Main.scala create mode 100644 float-safe-optimizer/examples/.gitignore create mode 100644 float-safe-optimizer/examples/add3.rise create mode 100644 float-safe-optimizer/examples/add3Seq.rise create mode 100644 float-safe-optimizer/examples/add3TypeError.rise diff --git a/.gitignore b/.gitignore index 6a98303bb..4f6a8aaf9 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ modules.xml *.pdf *.gz *.sc + +float-safe-optimizer.jar diff --git a/README.md b/README.md index c80375b9f..a482d8623 100644 --- a/README.md +++ b/README.md @@ -26,3 +26,20 @@ The source code for the compiler is organised into sub-packages of the `shine` p ### Setup and Documentation Please have a look at: https://rise-lang.org/doc/ + +### Float Safe Optimizer + +This repository contains an optimizer executable that preserves floating-point semantics. +To build a Fat JAR executable: +```sh +sbt float_safe_optimizer/assembly +``` +To optimize a Rise program and generate code: +```sh +java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar $function_name $rise_source_path $output_path +``` +For example: +```sh +java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar add3 float-safe-optimizer/examples/add3Seq.rise float-safe-optimizer/examples/add3Seq.c +java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar add3 float-safe-optimizer/examples/add3.rise float-safe-optimizer/examples/add3.c +``` \ No newline at end of file diff --git a/build.sbt b/build.sbt index bd06b1d57..0c7695f9e 100644 --- a/build.sbt +++ b/build.sbt @@ -104,5 +104,32 @@ clap := { "echo y" #| (baseDirectory.value + "/lib/clap/buildClap.sh") ! } - +lazy val float_safe_optimizer = (project in file("float-safe-optimizer")) + .dependsOn(riseAndShine) + .enablePlugins(AssemblyPlugin) + .settings( + excludeDependencies ++= Seq( + ExclusionRule("org.scala-lang.modules", s"scala-xml_${scalaBinaryVersion.value}"), + ExclusionRule("junit", "junit"), + ExclusionRule("com.novocode", "junit-interface"), + ExclusionRule("org.scalacheck", "scalacheck"), + ExclusionRule("org.scalatest", "scalatest"), + ExclusionRule("com.lihaoyi", s"os-lib_${scalaBinaryVersion.value}"), + ExclusionRule("com.typesafe.play", s"play-json_${scalaBinaryVersion.value}"), + ExclusionRule("org.rise-lang", s"opencl-executor_${scalaBinaryVersion.value}"), + ExclusionRule("org.rise-lang", "CUexecutor"), + ExclusionRule("org.elevate-lang", s"cuda-executor_${scalaBinaryVersion.value}"), + ExclusionRule("org.elevate-lang", s"meta_${scalaBinaryVersion.value}"), + ), + name := "float-safe-optimizer", + javaOptions ++= Seq("-Xss20m", "-Xms512m", "-Xmx4G"), + assemblyOutputPath in assembly := file("float-safe-optimizer.jar"), + assemblyMergeStrategy in assembly := { + case PathList("fasterxml", xs @ _*) => + MergeStrategy.discard + case x => + val oldStrategy = (assemblyMergeStrategy in assembly).value + oldStrategy(x) + } + ) diff --git a/float-safe-optimizer/Main.scala b/float-safe-optimizer/Main.scala new file mode 100644 index 000000000..46089fa6d --- /dev/null +++ b/float-safe-optimizer/Main.scala @@ -0,0 +1,40 @@ +package float_safe_optimizer + +import util.gen +import rise.core.Expr +import rise.core.DSL.ToBeTyped + +object Main { + def main(args: Array[String]): Unit = { + val name = args(0) + val exprSourcePath = args(1) + val outputPath = args(2) + + val exprSource = util.readFile(exprSourcePath) + val untypedExpr = parseExpr(prefixImports(exprSource)) + val typedExpr = untypedExpr.toExpr + val code = gen.openmp.function.asStringFromExpr(typedExpr) + util.writeToPath(outputPath, code) + } + + def prefixImports(source: String): String = + s""" + |import rise.core.DSL._ + |import rise.core.DSL.Type._ + |import rise.core.DSL.HighLevelConstructs._ + |import rise.core.primitives._ + |import rise.core.types._ + |import rise.core.types.DataType._ + |import rise.openmp.DSL._ + |import rise.openmp.primitives._ + |$source + |""".stripMargin + + def parseExpr(source: String): ToBeTyped[Expr] = { + import scala.reflect.runtime.universe + import scala.tools.reflect.ToolBox + + val toolbox = universe.runtimeMirror(getClass.getClassLoader).mkToolBox() + toolbox.eval(toolbox.parse(source)).asInstanceOf[ToBeTyped[Expr]] + } +} diff --git a/float-safe-optimizer/examples/.gitignore b/float-safe-optimizer/examples/.gitignore new file mode 100644 index 000000000..09b2ac1d1 --- /dev/null +++ b/float-safe-optimizer/examples/.gitignore @@ -0,0 +1 @@ +*.c \ No newline at end of file diff --git a/float-safe-optimizer/examples/add3.rise b/float-safe-optimizer/examples/add3.rise new file mode 100644 index 000000000..d25c729c1 --- /dev/null +++ b/float-safe-optimizer/examples/add3.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in => + map(add(li32(3)))(in) +)) \ No newline at end of file diff --git a/float-safe-optimizer/examples/add3Seq.rise b/float-safe-optimizer/examples/add3Seq.rise new file mode 100644 index 000000000..db7da18c7 --- /dev/null +++ b/float-safe-optimizer/examples/add3Seq.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in => + mapSeq(add(li32(3)))(in) +)) \ No newline at end of file diff --git a/float-safe-optimizer/examples/add3TypeError.rise b/float-safe-optimizer/examples/add3TypeError.rise new file mode 100644 index 000000000..117121818 --- /dev/null +++ b/float-safe-optimizer/examples/add3TypeError.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun((n`.`i32) ->: (n`.`f32))(in => + map(add(li32(3)))(in) +)) \ No newline at end of file diff --git a/project/plugins.sbt b/project/plugins.sbt index 76d8825fd..6676856df 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,5 @@ addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.0-RC1") addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.3.0") addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.2.17") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.15.0") +addDependencyTreePlugin \ No newline at end of file From 26090e25a87b371574f50a27204666b5776ccea8 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Sun, 24 Mar 2024 13:51:09 +0100 Subject: [PATCH 2/9] simple lowering strategy for float safe optimizer --- build.sbt | 7 -- float-safe-optimizer/Main.scala | 4 +- float-safe-optimizer/Optimize.scala | 85 +++++++++++++++++++ src/main/scala/rise/eqsat/Analysis.scala | 13 +-- .../scala/rise/eqsat/LoweringSearch.scala | 30 ++++--- 5 files changed, 116 insertions(+), 23 deletions(-) create mode 100644 float-safe-optimizer/Optimize.scala diff --git a/build.sbt b/build.sbt index 0c7695f9e..389c1cfa2 100644 --- a/build.sbt +++ b/build.sbt @@ -124,12 +124,5 @@ lazy val float_safe_optimizer = (project in file("float-safe-optimizer")) name := "float-safe-optimizer", javaOptions ++= Seq("-Xss20m", "-Xms512m", "-Xmx4G"), assemblyOutputPath in assembly := file("float-safe-optimizer.jar"), - assemblyMergeStrategy in assembly := { - case PathList("fasterxml", xs @ _*) => - MergeStrategy.discard - case x => - val oldStrategy = (assemblyMergeStrategy in assembly).value - oldStrategy(x) - } ) diff --git a/float-safe-optimizer/Main.scala b/float-safe-optimizer/Main.scala index 46089fa6d..58e3fdd1c 100644 --- a/float-safe-optimizer/Main.scala +++ b/float-safe-optimizer/Main.scala @@ -13,7 +13,9 @@ object Main { val exprSource = util.readFile(exprSourcePath) val untypedExpr = parseExpr(prefixImports(exprSource)) val typedExpr = untypedExpr.toExpr - val code = gen.openmp.function.asStringFromExpr(typedExpr) + val optimizedExpr = Optimize(typedExpr) + println(optimizedExpr) + val code = gen.openmp.function.asStringFromExpr(optimizedExpr) util.writeToPath(outputPath, code) } diff --git a/float-safe-optimizer/Optimize.scala b/float-safe-optimizer/Optimize.scala new file mode 100644 index 000000000..947e8da57 --- /dev/null +++ b/float-safe-optimizer/Optimize.scala @@ -0,0 +1,85 @@ +package float_safe_optimizer + +import rise.eqsat._ + +object Optimize { + def apply(e: rise.core.Expr): rise.core.Expr = { + val expr = Expr.fromNamed(e) + val (body, annotation, wrapBody) = analyseTopLevel(expr) + + val rules = { + import rise.eqsat.rules._ + Seq( + mapFusion, + reduceSeq, + mapSeq, + mapSeqArray, + removeTransposePair, + fstReduction, + sndReduction, + /* maybe: + omp.mapPar --> need heuristic vs mapSeq + toMemAfterMapSeq + storeToMem + reduceSeqMapFusion + reduceSeqUnroll --> need heuristic + */ + ) + } + + LoweringSearch.init().run(BENF, Cost, Seq(body), rules, Some(annotation)) match { + case Some(resBody) => + val res = wrapBody(resBody) + Expr.toNamed(res) + case None => throw new Exception("could not find valid low-level expression") + } + } + + // TODO: this code might be avoidable by making DPIA+codegen rely on top-level type instead of top-level constructs + def analyseTopLevel(e: Expr) + : (Expr, (BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation]), Expr => Expr) = { + import rise.eqsat.RWAnnotationDSL._ + + // returns (body, argCount, wrapBody) + def rec(e: Expr): (Expr, Int, Expr => Expr) = { + e.node match { + case NatLambda(e2) => + val (b, a, w) = rec(e2) + (b, a, b => Expr(NatLambda(w(b)), e.t)) + case DataLambda(e2) => + val (b, a, w) = rec(e2) + (b, a, b => Expr(DataLambda(w(b)), e.t)) + case Lambda(e2) => + e.t.node match { + case FunType(Type(_: DataTypeNode[_, _]), _) => () + case _ => throw new Exception("top level higher-order functions are not supported") + } + val (b, a, w) = rec(e2) + (b, a + 1, b => Expr(Lambda(w(b)), e.t)) + case _ => + if (!e.t.node.isInstanceOf[DataTypeNode[_, _]]) { + throw new Exception("expected body with data type") + } + (e, 0, b => b) + } + } + + val (b, a, w) = rec(e) + (b, (write, List.tabulate(a)(i => i -> read).toMap), w) + } + + object Cost extends CostFunction[Int] { + val ordering = implicitly + + override def cost(egraph: EGraph, enode: ENode, t: TypeId, costs: EClassId => Int): Int = { + import rise.core.primitives._ + + val nodeCost = enode match { + // prefer avoiding mapSeq + case Primitive(mapSeq()) => 5 + case _ => 1 + } + enode.children().foldLeft(nodeCost) { case (acc, eclass) => acc + costs(eclass) } + } + } +} diff --git a/src/main/scala/rise/eqsat/Analysis.scala b/src/main/scala/rise/eqsat/Analysis.scala index 1360eca30..5bd4ba6a6 100644 --- a/src/main/scala/rise/eqsat/Analysis.scala +++ b/src/main/scala/rise/eqsat/Analysis.scala @@ -687,7 +687,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) case App(f, e) => val fInT = egraph(egraph.get(f).t) match { case FunType(inT, _) => inT - case _ => throw new Exception("this should not happen") + case _ => throw new Exception("app expected fun type") } val eT = egraph.get(e).t @@ -713,7 +713,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) } } } - case _ => throw new Exception("this should not happen") + case _ => throw new Exception("app expected fun type") } } newBeams @@ -748,7 +748,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) } annotation match { case NotDataTypeAnnotation(NatFunType(at)) => (at, env) -> newBeam - case _ => throw new Exception("this should not happen") + case _ => throw new Exception("natApp expected NatFunType") } } case DataApp(f, _) => @@ -762,7 +762,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) } annotation match { case NotDataTypeAnnotation(DataFunType(at)) => (at, env) -> newBeam - case _ => throw new Exception("this should not happen") + case _ => throw new Exception("dataApp expected DataFunType") } } case AddrApp(f, _) => @@ -776,7 +776,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) } annotation match { case NotDataTypeAnnotation(AddrFunType(at)) => (at, env) -> newBeam - case _ => throw new Exception("this should not happen") + case _ => throw new Exception("addrApp expected AddrFunType") } } case NatLambda(e) => @@ -923,6 +923,9 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) } } Seq(rec(n)) + case rp.id() => + // FIXME: only supports non-functional values + Seq(read ->: read, write ->: write) case _ => throw new Exception(s"did not expect $p") } val beam = Seq(( diff --git a/src/main/scala/rise/eqsat/LoweringSearch.scala b/src/main/scala/rise/eqsat/LoweringSearch.scala index a4aa8bb78..4a33064c9 100644 --- a/src/main/scala/rise/eqsat/LoweringSearch.scala +++ b/src/main/scala/rise/eqsat/LoweringSearch.scala @@ -8,27 +8,37 @@ object LoweringSearch { // TODO: enable giving a sketch, maybe merge with GuidedSearch? class LoweringSearch(var filter: Predicate) { - private def topLevelAnnotation(e: Expr): BeamExtractRW.TypeAnnotation = { + private def topLevelAnnotation(t: Type): BeamExtractRW.TypeAnnotation = { import RWAnnotationDSL._ - e.node match { - case NatLambda(e) => nFunT(topLevelAnnotation(e)) - case DataLambda(e) => dtFunT(topLevelAnnotation(e)) - case Lambda(e) => read ->: topLevelAnnotation(e) - case _ => - assert(e.t.node.isInstanceOf[DataTypeNode[_, _]]) + t.node match { + case NatFunType(t) => nFunT(topLevelAnnotation(t)) + case DataFunType(t) => dtFunT(topLevelAnnotation(t)) + case AddrFunType(t) => aFunT(topLevelAnnotation(t)) + case FunType(ta, tb) => + if (!ta.node.isInstanceOf[DataTypeNode[_, _]]) { + throw new Exception("top level higher-order functions are not supported") + } + read ->: topLevelAnnotation(tb) + case _: DataTypeNode[_, _] => write + case _ => + throw new Exception(s"did not expect type $t") } } def run(normalForm: NF, costFunction: CostFunction[_], startBeam: Seq[Expr], - loweringRules: Seq[Rewrite]): Option[Expr] = { + loweringRules: Seq[Rewrite], + annotations: Option[(BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation])] = None): Option[Expr] = { println("---- lowering") val egraph = EGraph.empty() val normBeam = startBeam.map(normalForm.normalize) - val expectedAnnotation = topLevelAnnotation(normBeam.head) + val expectedAnnotations = annotations match { + case Some(annotations) => annotations + case None => (topLevelAnnotation(normBeam.head.t), Map.empty[Int, BeamExtractRW.TypeAnnotation]) + } val rootId = normBeam.map(egraph.addExpr) .reduce[EClassId] { case (a, b) => egraph.union(a, b)._1 } @@ -43,7 +53,7 @@ class LoweringSearch(var filter: Predicate) { util.printTime("lowered extraction time", { val tmp = Analysis.oneShot(BeamExtractRW(1, costFunction), egraph)(egraph.find(rootId)) - tmp.get((expectedAnnotation, Map.empty)) + tmp.get(expectedAnnotations) .map { beam => ExprWithHashCons.expr(egraph)(beam.head._2) } }) } From bce46e4903b9cf26ec355c49e1290c4d87f7d42f Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Tue, 14 May 2024 11:40:50 +0200 Subject: [PATCH 3/9] addv example --- float-safe-optimizer/Optimize.scala | 19 ++++++++++++++++--- float-safe-optimizer/examples/addv.rise | 4 ++++ 2 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 float-safe-optimizer/examples/addv.rise diff --git a/float-safe-optimizer/Optimize.scala b/float-safe-optimizer/Optimize.scala index 947e8da57..4a0c7733e 100644 --- a/float-safe-optimizer/Optimize.scala +++ b/float-safe-optimizer/Optimize.scala @@ -10,19 +10,32 @@ object Optimize { val rules = { import rise.eqsat.rules._ Seq( - mapFusion, + // implementation choices: reduceSeq, mapSeq, + // satisfying read/write annotations: mapSeqArray, + // simplifications: + mapFusion, + reduceSeqMapFusion, removeTransposePair, fstReduction, sndReduction, /* maybe: omp.mapPar --> need heuristic vs mapSeq - toMemAfterMapSeq + toMemAfterMapSeq / storeToMem storeToMem reduceSeqMapFusion - reduceSeqUnroll --> need heuristic + mapSeqUnroll/reduceSeqUnroll --> need heuristic + eliminateMapIdentity + + is it worth the cost?: + betaExtract + betaNatExtract + eta + + not generic enough, use Elevate passes or custom applier?: + idxReduction_i_n */ ) } diff --git a/float-safe-optimizer/examples/addv.rise b/float-safe-optimizer/examples/addv.rise new file mode 100644 index 000000000..f4ac071ee --- /dev/null +++ b/float-safe-optimizer/examples/addv.rise @@ -0,0 +1,4 @@ +depFun((n: Nat) => depFun((m: Nat) => depFun((o: Nat) => + fun(((n+o)`.`i32) ->: ((m+o)`.`i32) ->: (o`.`i32))((a, b) => + zip(take(o)(a))(take(o)(b)) |> map(fun(x => fst(x) + snd(x))) +)))) \ No newline at end of file From 04373e9263c7dd4cfa68504aea4315504a726114 Mon Sep 17 00:00:00 2001 From: Simon Date: Sun, 2 Mar 2025 18:19:36 +0100 Subject: [PATCH 4/9] Draft of DataType node sub --- float-safe-optimizer/examples/add3.rise | 6 +++++- src/main/scala/rise/eqsat/NodeSubs.scala | 26 +++++++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/float-safe-optimizer/examples/add3.rise b/float-safe-optimizer/examples/add3.rise index d25c729c1..4dcc9f1ee 100644 --- a/float-safe-optimizer/examples/add3.rise +++ b/float-safe-optimizer/examples/add3.rise @@ -1,3 +1,7 @@ depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in => map(add(li32(3)))(in) -)) \ No newline at end of file +<<<<<<< Updated upstream +)) +======= +)) +>>>>>>> Stashed changes diff --git a/src/main/scala/rise/eqsat/NodeSubs.scala b/src/main/scala/rise/eqsat/NodeSubs.scala index 9f993cd95..8464bb6d7 100644 --- a/src/main/scala/rise/eqsat/NodeSubs.scala +++ b/src/main/scala/rise/eqsat/NodeSubs.scala @@ -140,6 +140,9 @@ object NodeSubs { case other => egraph.add(other.map(n => replace(egraph, n, index, subs))) } } + + def replace(egraph: EGraph, id: NatId, + index: Int, subs: DataTypeId): NatId = id // nats cannot contain datatypes } object DataType { @@ -175,6 +178,17 @@ object NodeSubs { dt => replace(egraph, dt, index, subs) )) + def replace(egraph: EGraph, id: DataTypeId, + index: Int, subs: DataTypeId): DataTypeId = + egraph(id) match { + case DataTypeVar(i) if i == index => subs + case dtv: DataTypeVar => egraph.add(dtv) + case other => egraph.add(other.map( + Nat.replace(egraph, _, index, subs), + replace(egraph, _, index, subs) + )) + } + def replaceDataType(index: Int, subs: DataType): DataType = { ??? } @@ -220,9 +234,15 @@ object NodeSubs { }) def replace(egraph: EGraph, id: TypeId, - index: Int, subs: DataTypeId): TypeId = { - ??? - } + index: Int, subs: DataTypeId): TypeId = egraph.add(egraph(id) match { + case DataFunType(t) => + val t2 = replace(egraph, t, index + 1, DataType.shifted(egraph, subs, (0, 1), (0, 0))) + DataFunType(t2) + case other => other.map( + replace(egraph, _, index, subs), + Nat.replace(egraph, _, index, subs), + DataType.replace(egraph, _, index, subs)) + }) // substitutes %n0 for arg in this def withNatArgument(egraph: EGraph, From 6fa88a5cccf3fa3489a3ebd2ba4cb19c2c21b04f Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Mon, 3 Mar 2025 16:22:13 +0100 Subject: [PATCH 5/9] fix eqsat code paths for foreign function --- float-safe-optimizer/examples/add3.rise | 6 +----- float-safe-optimizer/examples/cos.rise | 3 +++ src/main/scala/rise/eqsat/Analysis.scala | 8 ++++++++ src/main/scala/rise/eqsat/NodeSubs.scala | 24 +++++++++++++++--------- src/test/scala/rise/eqsat/Basic.scala | 11 +++++++++++ 5 files changed, 38 insertions(+), 14 deletions(-) create mode 100644 float-safe-optimizer/examples/cos.rise diff --git a/float-safe-optimizer/examples/add3.rise b/float-safe-optimizer/examples/add3.rise index 4dcc9f1ee..d25c729c1 100644 --- a/float-safe-optimizer/examples/add3.rise +++ b/float-safe-optimizer/examples/add3.rise @@ -1,7 +1,3 @@ depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in => map(add(li32(3)))(in) -<<<<<<< Updated upstream -)) -======= -)) ->>>>>>> Stashed changes +)) \ No newline at end of file diff --git a/float-safe-optimizer/examples/cos.rise b/float-safe-optimizer/examples/cos.rise new file mode 100644 index 000000000..3a99b0db8 --- /dev/null +++ b/float-safe-optimizer/examples/cos.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun(n`.`f64)(x => + x |> mapSeq(foreignFun("cos", f64 ->: f64)) +)) \ No newline at end of file diff --git a/src/main/scala/rise/eqsat/Analysis.scala b/src/main/scala/rise/eqsat/Analysis.scala index 5bd4ba6a6..174fc02c9 100644 --- a/src/main/scala/rise/eqsat/Analysis.scala +++ b/src/main/scala/rise/eqsat/Analysis.scala @@ -926,6 +926,14 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) case rp.id() => // FIXME: only supports non-functional values Seq(read ->: read, write ->: write) + case rp.foreignFunction(_, _) => + def buildAnnot(t: rise.eqsat.TypeId): TypeAnnotation = egraph(t) match { + case _: DataTypeNode[NatId, DataTypeId] => read + case rise.eqsat.FunType(in, out) => buildAnnot(in) ->: buildAnnot(out) + case rise.eqsat.DataFunType(t) => dtFunT(buildAnnot(t)) + case node => throw new Exception(s"did not expect $node") + } + Seq(buildAnnot(t)) case _ => throw new Exception(s"did not expect $p") } val beam = Seq(( diff --git a/src/main/scala/rise/eqsat/NodeSubs.scala b/src/main/scala/rise/eqsat/NodeSubs.scala index 8464bb6d7..897abd5fc 100644 --- a/src/main/scala/rise/eqsat/NodeSubs.scala +++ b/src/main/scala/rise/eqsat/NodeSubs.scala @@ -234,15 +234,21 @@ object NodeSubs { }) def replace(egraph: EGraph, id: TypeId, - index: Int, subs: DataTypeId): TypeId = egraph.add(egraph(id) match { - case DataFunType(t) => - val t2 = replace(egraph, t, index + 1, DataType.shifted(egraph, subs, (0, 1), (0, 0))) - DataFunType(t2) - case other => other.map( - replace(egraph, _, index, subs), - Nat.replace(egraph, _, index, subs), - DataType.replace(egraph, _, index, subs)) - }) + index: Int, subs: DataTypeId): TypeId = + id match { + case dt: DataTypeId => DataType.replace(egraph, dt, index, subs) + case _: NotDataTypeId => egraph.add(egraph(id) match { + case DataFunType(t) => + val t2 = replace(egraph, t, index + 1, DataType.shifted(egraph, subs, (0, 1), (0, 0))) + DataFunType(t2) + case _: DataTypeNode[NatId, DataTypeId] => + throw new Exception("this should not happen") + case other => other.map( + replace(egraph, _, index, subs), + Nat.replace(egraph, _, index, subs), + DataType.replace(egraph, _, index, subs)) + }) + } // substitutes %n0 for arg in this def withNatArgument(egraph: EGraph, diff --git a/src/test/scala/rise/eqsat/Basic.scala b/src/test/scala/rise/eqsat/Basic.scala index e93bc073f..326f5c7bf 100644 --- a/src/test/scala/rise/eqsat/Basic.scala +++ b/src/test/scala/rise/eqsat/Basic.scala @@ -126,6 +126,17 @@ class Basic extends test_util.Tests { Seq()) } + + test("foreign function") { + import rise.core.types.DataType.f32 + + ProveEquiv.init().run( + Expr.fromNamed(foreignFun("cos", f32 ->: f32)), + Expr.fromNamed(foreignFun("cos", f32 ->: f32)), + Seq(), Seq() + ) + } + ignore("saturate associativity and fusion/fission") { withFuns(4) withFuns(5) From 56f3dba8e22281f61b4c2a4f6abe3596a4fc9ed5 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Sun, 16 Mar 2025 21:14:23 +0100 Subject: [PATCH 6/9] add more examples --- float-safe-optimizer/examples/generate.rise | 3 +++ float-safe-optimizer/examples/larr.rise | 7 +++++++ float-safe-optimizer/examples/makeArray.rise | 3 +++ float-safe-optimizer/examples/scalarOut.rise | 3 +++ 4 files changed, 16 insertions(+) create mode 100644 float-safe-optimizer/examples/generate.rise create mode 100644 float-safe-optimizer/examples/larr.rise create mode 100644 float-safe-optimizer/examples/makeArray.rise create mode 100644 float-safe-optimizer/examples/scalarOut.rise diff --git a/float-safe-optimizer/examples/generate.rise b/float-safe-optimizer/examples/generate.rise new file mode 100644 index 000000000..567d05a86 --- /dev/null +++ b/float-safe-optimizer/examples/generate.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun(i32 ->: i32 ->: 2`.`i32)((x, y) => + generate(fun(i => select(i =:= lidx(0, 2))(x)(y))) |> mapSeq(fun(x => x)) +)) \ No newline at end of file diff --git a/float-safe-optimizer/examples/larr.rise b/float-safe-optimizer/examples/larr.rise new file mode 100644 index 000000000..e61c6d9bc --- /dev/null +++ b/float-safe-optimizer/examples/larr.rise @@ -0,0 +1,7 @@ +import rise.core.semantics._ + +def larrToMem(s: Seq[Data]) = larr(s) |> mapSeqUnroll(fun(x => x)) |> toMem + +depFun((n: Nat) => + larrToMem(Seq(IntData(1), IntData(2), IntData(3))) +) \ No newline at end of file diff --git a/float-safe-optimizer/examples/makeArray.rise b/float-safe-optimizer/examples/makeArray.rise new file mode 100644 index 000000000..8a58368d9 --- /dev/null +++ b/float-safe-optimizer/examples/makeArray.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun(i32 ->: i32 ->: 2`.`i32)((x, y) => + makeArray(2)(x)(y) |> mapSeqUnroll(fun(x => x)) +)) \ No newline at end of file diff --git a/float-safe-optimizer/examples/scalarOut.rise b/float-safe-optimizer/examples/scalarOut.rise new file mode 100644 index 000000000..e029fa17b --- /dev/null +++ b/float-safe-optimizer/examples/scalarOut.rise @@ -0,0 +1,3 @@ +depFun((n: Nat) => fun(n`.`f64)(x => + x |> reduceSeq(add)(lf64(0.0)) +)) \ No newline at end of file From ce09d5d25e0baba77d0c7e211dc20134842a63e5 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Sun, 16 Mar 2025 21:16:06 +0100 Subject: [PATCH 7/9] WIP on read/write bug from eqsat lowering search --- src/main/scala/rise/eqsat/Analysis.scala | 9 ++-- .../scala/rise/eqsat/LoweringSearch.scala | 45 ++++++++++++++++--- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/main/scala/rise/eqsat/Analysis.scala b/src/main/scala/rise/eqsat/Analysis.scala index 174fc02c9..a4e4667a1 100644 --- a/src/main/scala/rise/eqsat/Analysis.scala +++ b/src/main/scala/rise/eqsat/Analysis.scala @@ -615,7 +615,7 @@ object BeamExtractRW { assert(at == bt) (a, b) match { case (DataTypeAnnotation(x), DataTypeAnnotation(y)) => - (x == y) || (x == rct.read || notContainingArrayType(bt.asInstanceOf[DataTypeId], egraph)) + (x == y) || (x == rct.read && notContainingArrayType(bt.asInstanceOf[DataTypeId], egraph)) case (NotDataTypeAnnotation(x), NotDataTypeAnnotation(y)) => (x, egraph(at), y, egraph(bt)) match { case (FunType(aIn, aOut), FunType(aInT, aOutT), FunType(bIn, bOut), FunType(bInT, bOutT)) => @@ -633,11 +633,8 @@ object BeamExtractRW { // TODO: could hash-cons this def notContainingArrayType(t: DataTypeId, egraph: EGraph): Boolean = { egraph(t) match { - case DataTypeVar(_) => false - case ScalarType(_) => true - case NatType => true - case VectorType(_, _) => true - case IndexType(_) => true + case DataTypeVar(_) => false // FIXME: this requires constraint? + case ScalarType(_) | NatType | VectorType(_, _) | IndexType(_) => true case PairType(dt1, dt2) => notContainingArrayType(dt1, egraph) && notContainingArrayType(dt2, egraph) case ArrayType(_, _) => false } diff --git a/src/main/scala/rise/eqsat/LoweringSearch.scala b/src/main/scala/rise/eqsat/LoweringSearch.scala index 4a33064c9..d32d8ed4e 100644 --- a/src/main/scala/rise/eqsat/LoweringSearch.scala +++ b/src/main/scala/rise/eqsat/LoweringSearch.scala @@ -26,11 +26,25 @@ class LoweringSearch(var filter: Predicate) { } } - def run(normalForm: NF, - costFunction: CostFunction[_], - startBeam: Seq[Expr], - loweringRules: Seq[Rewrite], - annotations: Option[(BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation])] = None): Option[Expr] = { + private def topLevelSubtype( + aAnnot: (BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation]), + bAnnot: (BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation]), + typ: TypeId, + egraph: EGraph, + ) = { + val (aOut, aIns) = aAnnot + val (bOut, bIns) = bAnnot + BeamExtractRW.subtype(aOut, typ, bOut, typ, egraph) && + aIns == bIns // TODO: could use subtype here as well in contravariant fashion + } + + def run[Cost]( + normalForm: NF, + costFunction: CostFunction[Cost], + startBeam: Seq[Expr], + loweringRules: Seq[Rewrite], + annotations: Option[(BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation])] = None + ): Option[Expr] = { println("---- lowering") val egraph = EGraph.empty() val normBeam = startBeam.map(normalForm.normalize) @@ -52,9 +66,28 @@ class LoweringSearch(var filter: Predicate) { r.printReport() util.printTime("lowered extraction time", { - val tmp = Analysis.oneShot(BeamExtractRW(1, costFunction), egraph)(egraph.find(rootId)) + val analysisResult = Analysis.oneShot(BeamExtractRW(1, costFunction), egraph)(egraph.find(rootId)) + // : Map[ + // (BeamExtractRW.TypeAnnotation, Map[Int,BeamExtractRW.TypeAnnotation]), + // Seq[(Cost, ExprWithHashCons)]] + + println("analysisResult", analysisResult) + println("expectedAnnotations", expectedAnnotations) + val validResults = analysisResult + .map { case (foundAnnot, foundBeam) => (foundAnnot, foundBeam.head) } + // first, filter correct subtypes on annotations + .filter { case (foundAnnot, found) => + topLevelSubtype(foundAnnot, expectedAnnotations, found._2.t, egraph) + } + println("validResults", validResults) + validResults + // then, get the best option + .minByOption { case (_, found) => found._1 }(costFunction.ordering) + .map { case (_, found) => ExprWithHashCons.expr(egraph)(found._2) } + /* without taking subtyping into account: tmp.get(expectedAnnotations) .map { beam => ExprWithHashCons.expr(egraph)(beam.head._2) } + */ }) } } From 890abde81114db9e3c74e422641fbda4e5c08e3e Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Tue, 18 Mar 2025 10:43:16 +0100 Subject: [PATCH 8/9] add eqsat support for nat and index literals, support length implementation via type inference --- float-safe-optimizer/Main.scala | 3 +- float-safe-optimizer/Optimize.scala | 6 ++-- float-safe-optimizer/examples/length-rw.rise | 5 ++++ float-safe-optimizer/examples/length.rise | 5 ++++ src/main/scala/rise/eqsat/Analysis.scala | 2 +- src/main/scala/rise/eqsat/EGraphDot.scala | 2 ++ src/main/scala/rise/eqsat/Expr.scala | 12 ++++++-- src/main/scala/rise/eqsat/Extractor.scala | 2 +- .../scala/rise/eqsat/LoweringSearch.scala | 3 +- src/main/scala/rise/eqsat/NamedRewrite.scala | 8 +++++ src/main/scala/rise/eqsat/Node.scala | 30 +++++++++++++++---- src/main/scala/rise/eqsat/NodeSubs.scala | 11 ++++++- src/main/scala/rise/eqsat/Rewrite.scala | 2 +- src/main/scala/rise/eqsat/Runner.scala | 2 +- src/main/scala/rise/eqsat/TypeCheck.scala | 6 ++++ src/main/scala/rise/eqsat/rules.scala | 6 ++++ 16 files changed, 88 insertions(+), 17 deletions(-) create mode 100644 float-safe-optimizer/examples/length-rw.rise create mode 100644 float-safe-optimizer/examples/length.rise diff --git a/float-safe-optimizer/Main.scala b/float-safe-optimizer/Main.scala index 58e3fdd1c..0e4ab6197 100644 --- a/float-safe-optimizer/Main.scala +++ b/float-safe-optimizer/Main.scala @@ -13,8 +13,9 @@ object Main { val exprSource = util.readFile(exprSourcePath) val untypedExpr = parseExpr(prefixImports(exprSource)) val typedExpr = untypedExpr.toExpr + println("typedExpr", typedExpr) val optimizedExpr = Optimize(typedExpr) - println(optimizedExpr) + println("optimizedExpr", optimizedExpr) val code = gen.openmp.function.asStringFromExpr(optimizedExpr) util.writeToPath(outputPath, code) } diff --git a/float-safe-optimizer/Optimize.scala b/float-safe-optimizer/Optimize.scala index 4a0c7733e..b6381ed34 100644 --- a/float-safe-optimizer/Optimize.scala +++ b/float-safe-optimizer/Optimize.scala @@ -6,6 +6,7 @@ object Optimize { def apply(e: rise.core.Expr): rise.core.Expr = { val expr = Expr.fromNamed(e) val (body, annotation, wrapBody) = analyseTopLevel(expr) + println("eqsat expr", expr) val rules = { import rise.eqsat.rules._ @@ -16,11 +17,12 @@ object Optimize { // satisfying read/write annotations: mapSeqArray, // simplifications: - mapFusion, - reduceSeqMapFusion, + // mapFusion, + // reduceSeqMapFusion, removeTransposePair, fstReduction, sndReduction, + reduceSeqOne, /* maybe: omp.mapPar --> need heuristic vs mapSeq toMemAfterMapSeq / storeToMem diff --git a/float-safe-optimizer/examples/length-rw.rise b/float-safe-optimizer/examples/length-rw.rise new file mode 100644 index 000000000..ecb4ed10f --- /dev/null +++ b/float-safe-optimizer/examples/length-rw.rise @@ -0,0 +1,5 @@ +def length() = reduceSeq(fun(acc => fun(x => acc + li32(1))))(li32(0)) + +depFun((n: Nat) => fun(n`.`i32)(x => + x |> mapSeq(fun(y => y + length()(x))) +)) \ No newline at end of file diff --git a/float-safe-optimizer/examples/length.rise b/float-safe-optimizer/examples/length.rise new file mode 100644 index 000000000..f82f45012 --- /dev/null +++ b/float-safe-optimizer/examples/length.rise @@ -0,0 +1,5 @@ +def length() = impl { n: Nat => impl{ dt: DataType => fun(n`.`dt)(x => l(n)) } } + +depFun((n: Nat) => fun(n`.`i32)(x => + x |> mapSeq(fun(y => y + (cast(length()(x)) :: i32))) +)) \ No newline at end of file diff --git a/src/main/scala/rise/eqsat/Analysis.scala b/src/main/scala/rise/eqsat/Analysis.scala index b9d8cfbf6..ed63b17f5 100644 --- a/src/main/scala/rise/eqsat/Analysis.scala +++ b/src/main/scala/rise/eqsat/Analysis.scala @@ -815,7 +815,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) // note: recording DataFunType() constructor is useless (NotDataTypeAnnotation(AddrFunType(annotation)), env) -> newBeam } - case Literal(_) => + case Literal(_) | NatLiteral(_) | IndexLiteral(_, _) => val beam = Seq(( cf.cost(egraph, enode, t, Map.empty), ExprWithHashCons(enode.mapChildren(Map.empty), t) diff --git a/src/main/scala/rise/eqsat/EGraphDot.scala b/src/main/scala/rise/eqsat/EGraphDot.scala index c50bb6449..6e640647b 100644 --- a/src/main/scala/rise/eqsat/EGraphDot.scala +++ b/src/main/scala/rise/eqsat/EGraphDot.scala @@ -113,6 +113,8 @@ case class EGraphDot(egraph: EGraph, case AddrApp(_, a) => s"aApp $a" case AddrLambda(_) => "Λ : addr" case Literal(d) => s"$d" + case NatLiteral(n) => s"$n" + case IndexLiteral(i, n) => s"idx($i, $n)" case Primitive(p) => s"${p.toString.trim}" case Composition(_, _) => ">>" } diff --git a/src/main/scala/rise/eqsat/Expr.scala b/src/main/scala/rise/eqsat/Expr.scala index c823bc65d..e12b05e49 100644 --- a/src/main/scala/rise/eqsat/Expr.scala +++ b/src/main/scala/rise/eqsat/Expr.scala @@ -160,14 +160,16 @@ object Expr { DataApp(fromNamed(f, bound), DataType.fromNamed(dt, bound)) case core.DepApp(rct.AddressSpaceKind, f, a: rct.AddressSpace) => AddrApp(fromNamed(f, bound), Address.fromNamed(a, bound)) - case core.DepApp(_, _, _) => ??? + case core.DepApp(k, _, _) => throw new Exception(s"missing DepApp case for $k") case core.DepLambda(rct.NatKind, n: rct.NatIdentifier, e) => NatLambda(fromNamed(e, bound + n)) case core.DepLambda(rct.DataKind, dt: rcdt.DataTypeIdentifier, e) => DataLambda(fromNamed(e, bound + dt)) case core.DepLambda(rct.AddressSpaceKind, a: rct.AddressSpaceIdentifier, e) => AddrLambda(fromNamed(e, bound + a)) - case core.DepLambda(_, _, _) => ??? + case core.DepLambda(k, _, _) => throw new Exception(s"missing DepLambda case for $k") + case core.Literal(core.semantics.NatData(n)) => NatLiteral(Nat.fromNamed(n, bound)) + case core.Literal(core.semantics.IndexData(i, n)) => IndexLiteral(Nat.fromNamed(i, bound), Nat.fromNamed(n, bound)) case core.Literal(d) => Literal(d) // note: we set the primitive type to a place holder here, // because we do not want type information at the node level @@ -202,6 +204,9 @@ object Expr { core.DepLambda(rct.AddressSpaceKind, i, toNamed(e, bound + i)) _ case Literal(d) => core.Literal(d).setType _ case Primitive(p) => p.setType _ + case NatLiteral(n) => core.Literal(core.semantics.NatData(Nat.toNamed(n, bound))).setType _ + case IndexLiteral(i, n) => core.Literal(core.semantics.IndexData( + Nat.toNamed(i, bound), Nat.toNamed(n, bound))).setType _ case Composition(f, g) => /* val f2 = f.shifted((1, 0, 0), (0, 0, 0)) @@ -251,6 +256,9 @@ object Expr { val i = rct.AddressSpaceIdentifier(s"a${bound.data.size}") core.DepLambda(rct.AddressSpaceKind, i, rec(e, bound + i)) _ case Literal(d) => core.Literal(d).setType _ + case NatLiteral(n) => core.Literal(core.semantics.NatData(Nat.toNamed(n, bound))).setType _ + case IndexLiteral(i, n) => core.Literal(core.semantics.IndexData( + Nat.toNamed(i, bound), Nat.toNamed(n, bound))).setType _ case Primitive(p) => p.setType _ case Composition(f, g) => /* diff --git a/src/main/scala/rise/eqsat/Extractor.scala b/src/main/scala/rise/eqsat/Extractor.scala index c18058190..b692b58b4 100644 --- a/src/main/scala/rise/eqsat/Extractor.scala +++ b/src/main/scala/rise/eqsat/Extractor.scala @@ -202,7 +202,7 @@ case class BENFRedexCount(/*egraph: EGraph*/) extends CostFunction[BENFRedexCoun case AddrLambda(e) => val ed = costs(e) Data(ed.redexes, ed.free, isEtaApp = false, isLam = false, isNatLam = false) - case Literal(_) | Primitive(_) => + case Literal(_) | NatLiteral(_) | IndexLiteral(_, _) | Primitive(_) => Data(0, Set(), isEtaApp = false, isLam = false, isNatLam = false) case Composition(f, g) => val fd = costs(f) diff --git a/src/main/scala/rise/eqsat/LoweringSearch.scala b/src/main/scala/rise/eqsat/LoweringSearch.scala index 4a33064c9..a5c58aaf6 100644 --- a/src/main/scala/rise/eqsat/LoweringSearch.scala +++ b/src/main/scala/rise/eqsat/LoweringSearch.scala @@ -34,6 +34,7 @@ class LoweringSearch(var filter: Predicate) { println("---- lowering") val egraph = EGraph.empty() val normBeam = startBeam.map(normalForm.normalize) + println(s"normalized: $normBeam") val expectedAnnotations = annotations match { case Some(annotations) => annotations @@ -48,7 +49,7 @@ class LoweringSearch(var filter: Predicate) { .withTimeLimit(java.time.Duration.ofMinutes(1)) .withMemoryLimit(4L * 1024L * 1024L * 1024L /* 4GiB */) .withNodeLimit(1_000_000) - .run(egraph, filter, loweringRules, normalForm.directedRules, Seq(rootId)) + .run(egraph, filter, loweringRules, Seq()/*normalForm.directedRules*/, Seq(rootId)) r.printReport() util.printTime("lowered extraction time", { diff --git a/src/main/scala/rise/eqsat/NamedRewrite.scala b/src/main/scala/rise/eqsat/NamedRewrite.scala index 46c05c538..b2e744067 100644 --- a/src/main/scala/rise/eqsat/NamedRewrite.scala +++ b/src/main/scala/rise/eqsat/NamedRewrite.scala @@ -150,6 +150,10 @@ object NamedRewrite { makePat(f, bound, isRhs, matchType = false), makeAPat(x, bound, isRhs))) case rc.DepApp(_, _, _) => ??? + case rc.Literal(rc.semantics.NatData(n)) => + PatternNode(NatLiteral(makeNPat(n, bound, isRhs))) + case rc.Literal(rc.semantics.IndexData(i, n)) => + PatternNode(IndexLiteral(makeNPat(i, bound, isRhs), makeNPat(n, bound, isRhs))) case rc.Literal(d) => PatternNode(Literal(d)) // note: we set the primitive type to a place holder here, // because we do not want type information at the node level @@ -567,7 +571,9 @@ object NamedRewriteDSL { } def l(d: rc.semantics.Data): Pattern = rc.Literal(d) def lf32(f: Float): Pattern = l(rise.core.semantics.FloatData(f)) + def li32(i: Int): Pattern = app(cast, l(rise.core.semantics.IntData(i))) def lidx(i: Int, n: Int) = l(rise.core.semantics.IndexData(i, n)) + def lnat(n: rct.Nat) = l(rise.core.semantics.NatData(n)) def slide: Pattern = rcp.slide.primitive def map: Pattern = rcp.map.primitive @@ -591,6 +597,7 @@ object NamedRewriteDSL { def asVector: Pattern = rcp.asVector.primitive def asVectorAligned: Pattern = rcp.asVectorAligned.primitive def vectorFromScalar: Pattern = rcp.vectorFromScalar.primitive + def cast: Pattern = rcp.cast.primitive def `?n`: NatPattern = rct.NatIdentifier(rc.freshName("n")) @@ -612,6 +619,7 @@ object NamedRewriteDSL { rct.TypeIdentifier(name) val int: DataTypePattern = rcdt.int + val i32: DataTypePattern = rcdt.i32 val f32: DataTypePattern = rcdt.f32 implicit final class TypeAnnotation(private val t: TypePattern) extends AnyVal { diff --git a/src/main/scala/rise/eqsat/Node.scala b/src/main/scala/rise/eqsat/Node.scala index c84e1cb8a..9d40aeb05 100644 --- a/src/main/scala/rise/eqsat/Node.scala +++ b/src/main/scala/rise/eqsat/Node.scala @@ -25,6 +25,8 @@ sealed trait Node[+E, +N, +DT, +A] { case DataLambda(e) => DataLambda(fe(e)) case AddrApp(f, x) => AddrApp(fe(f), fa(x)) case AddrLambda(e) => AddrLambda(fe(e)) + case NatLiteral(n) => NatLiteral(fn(n)) + case IndexLiteral(i, n) => IndexLiteral(fn(i), fn(n)) case Composition(f, g) => Composition(fe(f), fe(g)) } @@ -32,7 +34,8 @@ sealed trait Node[+E, +N, +DT, +A] { def mapChildren[OE](fc: E => OE): Node[OE, N, DT, A] = map(fc, n => n, dt => dt, a => a) def children(): Iterator[E] = this match { - case Var(_) | Literal(_) | Primitive(_) => Iterator() + case Var(_) | Literal(_) | NatLiteral(_) | IndexLiteral(_, _) | + Primitive(_) => Iterator() case App(f, e) => Iterator(f, e) case Lambda(e) => Iterator(e) case NatApp(f, _) => Iterator(f) @@ -49,6 +52,8 @@ sealed trait Node[+E, +N, +DT, +A] { def nats(): Iterator[N] = this match { case NatApp(_, n) => Iterator(n) + case NatLiteral(n) => Iterator(n) + case IndexLiteral(i, n) => Iterator(i, n) case _ => Iterator() } def natsCount(): Int = nats().length @@ -76,6 +81,8 @@ sealed trait Node[+E, +N, +DT, +A] { case AddrApp(_, _) => 7 case AddrLambda(_) => 8 case Literal(d) => 17 * d.hashCode() + case NatLiteral(n) => 23 * n.hashCode() + case IndexLiteral(i, n) => 29 * (i, n).hashCode() case Primitive(p) => 19 * p.setType(rct.TypePlaceholder).hashCode() case Composition(_, _) => 9 @@ -94,6 +101,8 @@ sealed trait Node[+E, +N, +DT, +A] { case (DataLambda(_), DataLambda(_)) => true case (AddrLambda(_), AddrLambda(_)) => true case (Literal(d1), Literal(d2)) => d1 == d2 + case (NatLiteral(n1), NatLiteral(n2)) => n1 == n2 + case (IndexLiteral(i1, n1), IndexLiteral(i2, n2)) => i1 == i2 && n1 == n2 case (Primitive(p1), Primitive(p2)) => // TODO: type should not be inside the primitive? p1.setType(rct.TypePlaceholder) == p2.setType(rct.TypePlaceholder) @@ -117,6 +126,8 @@ case class AddrLambda[E](e: E) extends Node[E, Nothing, Nothing, Nothing] case class Literal(d: semantics.Data) extends Node[Nothing, Nothing, Nothing, Nothing] { override def toString: String = d.toString } +case class NatLiteral[N](n: N) extends Node[Nothing, N, Nothing, Nothing] +case class IndexLiteral[N](i: N, n: N) extends Node[Nothing, N, Nothing, Nothing] case class Primitive(p: rise.core.Primitive) extends Node[Nothing, Nothing, Nothing, Nothing] { override def toString: String = p.toString.trim } @@ -139,6 +150,8 @@ object Node { case DataLambda(e) => Seq(e) case AddrLambda(e) => Seq(e) case Literal(_) => Seq() + case NatLiteral(n) => Seq(n) + case IndexLiteral(i, n) => Seq(i, n) case Primitive(_) => Seq() case Composition(f, g) => Seq(f, g) @@ -201,13 +214,11 @@ object Node { def compare(d1: Data, d2: Data): Int = (d1, d2) match { - case (NatData(n1), NatData(n2)) => ??? // natOrdering.compare(n1, n2) - case (NatData(_), _) => ??? // -1 - case (_, NatData(_)) => ??? // 1 + // case (NatData(n1), NatData(n2)) => ??? // natOrdering.compare(n1, n2) + case (NatData(_), _) | (_, NatData(_)) => throw new Exception("should not compare 'NatData'") case (IndexData(i1, n1), IndexData(i2, n2)) => ??? // implicitly[Ordering[(Nat, Nat)]].compare((i1, n1), (i2, n2)) - case (IndexData(_, _), _) => ??? // -1 - case (_, IndexData(_, _)) => ??? // 1 + case (IndexData(_, _), _) | (_, IndexData(_, _)) => throw new Exception("should not compare 'IndexData'") case (sd1: ScalarData, sd2: ScalarData) => scalarDataOrdering.compare(sd1, sd2) case (_: ScalarData, _) => -1 case (_, _: ScalarData) => 1 @@ -293,6 +304,13 @@ object Node { case (Literal(d1), Literal(d2)) => dataOrdering.compare(d1, d2) case (Literal(_), _) => -1 case (_, Literal(_)) => 1 + case (NatLiteral(n1), NatLiteral(n2)) => nOrd.compare(n1, n2) + case (NatLiteral(_), _) => -1 + case (_, NatLiteral(_)) => 1 + case (IndexLiteral(i1, n1), IndexLiteral(i2, n2)) => + implicitly[Ordering[(N, N)]].compare((i1, n1), (i2, n2)) + case (IndexLiteral(_, _), _) => -1 + case (_, IndexLiteral(_, _)) => 1 case (Composition(f1, g1), Composition(f2, g2)) => implicitly[Ordering[(E, E)]].compare((f1, g1), (f2, g2)) diff --git a/src/main/scala/rise/eqsat/NodeSubs.scala b/src/main/scala/rise/eqsat/NodeSubs.scala index 897abd5fc..d5e1fc4e7 100644 --- a/src/main/scala/rise/eqsat/NodeSubs.scala +++ b/src/main/scala/rise/eqsat/NodeSubs.scala @@ -1,4 +1,5 @@ package rise.eqsat +import rise.eqsat.NatLiteral object NodeSubs { /** Shifts De-Bruijn indices up or down if they are >= cutoff @@ -31,6 +32,10 @@ object NodeSubs { case AddrApp(f, x) => AddrApp(shiftedE(f, shift, cutoff), Address.shifted(x, shift._4, cutoff._4)) + case NatLiteral(n) => NatLiteral(Nat.shifted(egraph, n, shift._2, cutoff._2)) + case IndexLiteral(i, n) => IndexLiteral( + Nat.shifted(egraph, i, shift._2, cutoff._2), + Nat.shifted(egraph, n, shift._2, cutoff._2)) case Literal(_) | Primitive(_) => n case Composition(f, g) => @@ -43,7 +48,7 @@ object NodeSubs { (shiftedE: (E, Expr.Shift, Expr.Shift) => E): E = n match { case Var(idx) if idx == index => subs - case Var(_) | Literal(_) | Primitive(_) => makeE(n) + case Var(_) | Literal(_) | NatLiteral(_) | IndexLiteral(_, _) | Primitive(_) => makeE(n) case Lambda(e) => // TODO: could shift lazily val e2 = replaceE(e, index + 1, shiftedE(subs, (1, 0, 0, 0), (0, 0, 0, 0))) @@ -100,6 +105,10 @@ object NodeSubs { AddrLambda(replaceE(e, index, subs)) case AddrApp(f, x) => AddrApp(replaceE(f, index, subs), x) + case NatLiteral(n) => + NatLiteral(Nat.replace(egraph, n, index, subs)) + case IndexLiteral(i, n) => + IndexLiteral(Nat.replace(egraph, i, index, subs), Nat.replace(egraph, n, index, subs)) case Composition(f, g) => Composition(replaceE(f, index, subs), replaceE(g, index, subs)) diff --git a/src/main/scala/rise/eqsat/Rewrite.scala b/src/main/scala/rise/eqsat/Rewrite.scala index 772cdacc0..61962f596 100644 --- a/src/main/scala/rise/eqsat/Rewrite.scala +++ b/src/main/scala/rise/eqsat/Rewrite.scala @@ -508,7 +508,7 @@ case class VectorizeScalarFunExtractApplier(f: PatternVar, n: NatPatternVar, fV: case NatLambda(_) => None case DataLambda(_) => None case AddrLambda(_) => None - case Literal(_) => + case Literal(_) | NatLiteral(_) | IndexLiteral(_, _) => for { tv <- vecDT(expr.t, n, eg) } yield ExprWithHashCons(App( ExprWithHashCons(Primitive(rcp.vectorFromScalar.primitive), eg.add(FunType(expr.t, tv))), diff --git a/src/main/scala/rise/eqsat/Runner.scala b/src/main/scala/rise/eqsat/Runner.scala index 4519033cd..5aa5da7b3 100644 --- a/src/main/scala/rise/eqsat/Runner.scala +++ b/src/main/scala/rise/eqsat/Runner.scala @@ -135,7 +135,7 @@ class Runner(var iterations: Vec[Iteration], if (stopReasons.nonEmpty) { return end() } val iter = runOne(egraph, roots, filter, rules, normRules) - // println(iter) + println(iter) if (iter.applied.isEmpty && scheduler.canSaturate(iterations.size)) { diff --git a/src/main/scala/rise/eqsat/TypeCheck.scala b/src/main/scala/rise/eqsat/TypeCheck.scala index 0016d9bd8..eb2279d1b 100644 --- a/src/main/scala/rise/eqsat/TypeCheck.scala +++ b/src/main/scala/rise/eqsat/TypeCheck.scala @@ -97,6 +97,12 @@ object TypeCheck { case Literal(d) => // TODO: more efficient egraph.addTypeFromNamed? assertSameType(t, egraph.addType(Type.fromNamed(d.dataType))) + case NatLiteral(_) => + () + // TODO: assertSameType(t, egraph.addType(NatType)) + case IndexLiteral(_, _) => + () + // TODO: assertSameType(t, egraph.addType(IndexType(?))) case Primitive(_) => // TODO: check p.typeScheme consistency? diff --git a/src/main/scala/rise/eqsat/rules.scala b/src/main/scala/rise/eqsat/rules.scala index fdc5d6b7c..1179d88d1 100644 --- a/src/main/scala/rise/eqsat/rules.scala +++ b/src/main/scala/rise/eqsat/rules.scala @@ -612,6 +612,12 @@ object rules { nApp(nApp(slide, "sz"), 1) --> app(nApp(rcp.rotateValues.primitive, "sz"), lam("x", "x")) ) + val reduceSeqOne = NamedRewrite.init("reduce-seq-one", + (app(app(app(rcp.reduceSeq.primitive, lam("acc", lam("x", app(app(add, "acc"), li32(1))))), li32(0)), "in" :: ("n" : Nat)`.``?dt`) :: i32) + --> + (app(cast, lnat("n")) :: i32) + ) + object omp { val mapPar = NamedRewrite.init("map-par", map --> rise.openMP.primitives.mapPar.primitive From 9374d06964ae85b94b6526b7ca8c41733c656db4 Mon Sep 17 00:00:00 2001 From: Thomas Koehler Date: Tue, 18 Mar 2025 16:21:02 +0100 Subject: [PATCH 9/9] fix read/write analysis bug --- src/main/scala/rise/eqsat/Analysis.scala | 14 +++++++++++--- src/main/scala/rise/eqsat/LoweringSearch.scala | 16 ++++++++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/main/scala/rise/eqsat/Analysis.scala b/src/main/scala/rise/eqsat/Analysis.scala index 89bbccbfe..668bcc8dd 100644 --- a/src/main/scala/rise/eqsat/Analysis.scala +++ b/src/main/scala/rise/eqsat/Analysis.scala @@ -572,8 +572,14 @@ object BeamExtractRW { sealed trait TypeAnnotation case class NotDataTypeAnnotation(node: TypeNode[TypeAnnotation, Unit, rct.Access]) extends TypeAnnotation + { + override def toString: String = node.toString() + } case class DataTypeAnnotation(access: rct.Access) extends TypeAnnotation + { + override def toString: String = access.toString() + } type Data[Cost] = Map[(TypeAnnotation, Map[Int, TypeAnnotation]), Seq[(Cost, ExprWithHashCons)]] @@ -613,7 +619,7 @@ object BeamExtractRW { def subtype(a: TypeAnnotation, at: TypeId, b: TypeAnnotation, bt: TypeId, egraph: EGraph): Boolean = { assert(at == bt) - (a, b) match { + val res = (a, b) match { case (DataTypeAnnotation(x), DataTypeAnnotation(y)) => (x == y) || (x == rct.read && notContainingArrayType(bt.asInstanceOf[DataTypeId], egraph)) case (NotDataTypeAnnotation(x), NotDataTypeAnnotation(y)) => @@ -628,12 +634,14 @@ object BeamExtractRW { } case _ => throw new Exception("this should not happen") } + // println(s"subtype: $a : ${egraph(at)} <= $b : ${egraph(bt)} ? $res") + res } // TODO: could hash-cons this def notContainingArrayType(t: DataTypeId, egraph: EGraph): Boolean = { egraph(t) match { - case DataTypeVar(_) => false // FIXME: this requires constraint? + case DataTypeVar(_) => false case ScalarType(_) | NatType | VectorType(_, _) | IndexType(_) => true case PairType(dt1, dt2) => notContainingArrayType(dt1, egraph) && notContainingArrayType(dt2, egraph) case ArrayType(_, _) => false @@ -696,7 +704,7 @@ case class BeamExtractRW[Cost](beamSize: Int, cf: CostFunction[Cost]) case NotDataTypeAnnotation(FunType(fIn, fOut)) => eBeams.foreach { case ((eAnnotation, eEnv), eBeam) => mergeEnv(fEnv, eEnv).foreach { mergedEnv => - if (subtype(fIn, fInT, eAnnotation, eT, egraph)) { + if (subtype(eAnnotation, eT, fIn, fInT, egraph)) { val newBeam = fBeam.flatMap { x => eBeam.flatMap { y => Seq(( cf.cost(egraph, enode, t, Map(f -> x._1, e -> y._1)), diff --git a/src/main/scala/rise/eqsat/LoweringSearch.scala b/src/main/scala/rise/eqsat/LoweringSearch.scala index 41c6e1844..a43e680e1 100644 --- a/src/main/scala/rise/eqsat/LoweringSearch.scala +++ b/src/main/scala/rise/eqsat/LoweringSearch.scala @@ -67,20 +67,28 @@ class LoweringSearch(var filter: Predicate) { r.printReport() util.printTime("lowered extraction time", { - val analysisResult = Analysis.oneShot(BeamExtractRW(1, costFunction), egraph)(egraph.find(rootId)) + val allAnalysisResult = Analysis.oneShot(BeamExtractRW(1, costFunction), egraph) // : Map[ // (BeamExtractRW.TypeAnnotation, Map[Int,BeamExtractRW.TypeAnnotation]), // Seq[(Cost, ExprWithHashCons)]] - println("analysisResult", analysisResult) - println("expectedAnnotations", expectedAnnotations) + /* DEBUG: + println("allAnalysisResult") + allAnalysisResult.foreach { case (id, map) => + println(id, ";", egraph.classes(id).nodes, ":", map) + } + println("----------") */ + + val analysisResult = allAnalysisResult(egraph.find(rootId)) + // DEBUG: println("analysisResult", analysisResult) + // DEBUG: println("expectedAnnotations", expectedAnnotations) val validResults = analysisResult .map { case (foundAnnot, foundBeam) => (foundAnnot, foundBeam.head) } // first, filter correct subtypes on annotations .filter { case (foundAnnot, found) => topLevelSubtype(foundAnnot, expectedAnnotations, found._2.t, egraph) } - println("validResults", validResults) + // DEBUG: println("validResults", validResults) validResults // then, get the best option .minByOption { case (_, found) => found._1 }(costFunction.ordering)