Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fat JAR executable that generates C code from Rise programs #240

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ modules.xml
*.gz
*.sc

float-safe-optimizer.jar
*.o
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,20 @@ The source code for the compiler is organised into sub-packages of the `shine` p

### Setup and Documentation
Please have a look at: https://rise-lang.org/doc/

### Float Safe Optimizer

This repository contains an optimizer executable that preserves floating-point semantics.
To build a Fat JAR executable:
```sh
sbt float_safe_optimizer/assembly
```
To optimize a Rise program and generate code:
```sh
java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar $function_name $rise_source_path $output_path
```
For example:
```sh
java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar add3 float-safe-optimizer/examples/add3Seq.rise float-safe-optimizer/examples/add3Seq.c
java -Xss20m -Xms512m -Xmx4G -jar float-safe-optimizer.jar add3 float-safe-optimizer/examples/add3.rise float-safe-optimizer/examples/add3.c
```
22 changes: 21 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,25 @@ clap := {
"echo y" #| (baseDirectory.value + "/lib/clap/buildClap.sh") !
}


lazy val float_safe_optimizer = (project in file("float-safe-optimizer"))
.dependsOn(riseAndShine)
.enablePlugins(AssemblyPlugin)
.settings(
excludeDependencies ++= Seq(
ExclusionRule("org.scala-lang.modules", s"scala-xml_${scalaBinaryVersion.value}"),
ExclusionRule("junit", "junit"),
ExclusionRule("com.novocode", "junit-interface"),
ExclusionRule("org.scalacheck", "scalacheck"),
ExclusionRule("org.scalatest", "scalatest"),
ExclusionRule("com.lihaoyi", s"os-lib_${scalaBinaryVersion.value}"),
ExclusionRule("com.typesafe.play", s"play-json_${scalaBinaryVersion.value}"),
ExclusionRule("org.rise-lang", s"opencl-executor_${scalaBinaryVersion.value}"),
ExclusionRule("org.rise-lang", "CUexecutor"),
ExclusionRule("org.elevate-lang", s"cuda-executor_${scalaBinaryVersion.value}"),
ExclusionRule("org.elevate-lang", s"meta_${scalaBinaryVersion.value}"),
),
name := "float-safe-optimizer",
javaOptions ++= Seq("-Xss20m", "-Xms512m", "-Xmx4G"),
assemblyOutputPath in assembly := file("float-safe-optimizer.jar"),
)

43 changes: 43 additions & 0 deletions float-safe-optimizer/Main.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package float_safe_optimizer

import util.gen
import rise.core.Expr
import rise.core.DSL.ToBeTyped

object Main {
def main(args: Array[String]): Unit = {
val name = args(0)
val exprSourcePath = args(1)
val outputPath = args(2)

val exprSource = util.readFile(exprSourcePath)
val untypedExpr = parseExpr(prefixImports(exprSource))
val typedExpr = untypedExpr.toExpr
println("typedExpr", typedExpr)
val optimizedExpr = Optimize(typedExpr)
println("optimizedExpr", optimizedExpr)
val code = gen.openmp.function.asStringFromExpr(optimizedExpr)
util.writeToPath(outputPath, code)
}

def prefixImports(source: String): String =
s"""
|import rise.core.DSL._
|import rise.core.DSL.Type._
|import rise.core.DSL.HighLevelConstructs._
|import rise.core.primitives._
|import rise.core.types._
|import rise.core.types.DataType._
|import rise.openmp.DSL._
|import rise.openmp.primitives._
|$source
|""".stripMargin

def parseExpr(source: String): ToBeTyped[Expr] = {
import scala.reflect.runtime.universe
import scala.tools.reflect.ToolBox

val toolbox = universe.runtimeMirror(getClass.getClassLoader).mkToolBox()
toolbox.eval(toolbox.parse(source)).asInstanceOf[ToBeTyped[Expr]]
}
}
100 changes: 100 additions & 0 deletions float-safe-optimizer/Optimize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package float_safe_optimizer

import rise.eqsat._

object Optimize {
def apply(e: rise.core.Expr): rise.core.Expr = {
val expr = Expr.fromNamed(e)
val (body, annotation, wrapBody) = analyseTopLevel(expr)
println("eqsat expr", expr)

val rules = {
import rise.eqsat.rules._
Seq(
// implementation choices:
reduceSeq,
mapSeq,
// satisfying read/write annotations:
mapSeqArray,
// simplifications:
// mapFusion,
// reduceSeqMapFusion,
removeTransposePair,
fstReduction,
sndReduction,
reduceSeqOne,
/* maybe:
omp.mapPar --> need heuristic vs mapSeq
toMemAfterMapSeq / storeToMem
storeToMem
reduceSeqMapFusion
mapSeqUnroll/reduceSeqUnroll --> need heuristic
eliminateMapIdentity

is it worth the cost?:
betaExtract
betaNatExtract
eta

not generic enough, use Elevate passes or custom applier?:
idxReduction_i_n
*/
)
}

LoweringSearch.init().run(BENF, Cost, Seq(body), rules, Some(annotation)) match {
case Some(resBody) =>
val res = wrapBody(resBody)
Expr.toNamed(res)
case None => throw new Exception("could not find valid low-level expression")
}
}

// TODO: this code might be avoidable by making DPIA+codegen rely on top-level type instead of top-level constructs
def analyseTopLevel(e: Expr)
: (Expr, (BeamExtractRW.TypeAnnotation, Map[Int, BeamExtractRW.TypeAnnotation]), Expr => Expr) = {
import rise.eqsat.RWAnnotationDSL._

// returns (body, argCount, wrapBody)
def rec(e: Expr): (Expr, Int, Expr => Expr) = {
e.node match {
case NatLambda(e2) =>
val (b, a, w) = rec(e2)
(b, a, b => Expr(NatLambda(w(b)), e.t))
case DataLambda(e2) =>
val (b, a, w) = rec(e2)
(b, a, b => Expr(DataLambda(w(b)), e.t))
case Lambda(e2) =>
e.t.node match {
case FunType(Type(_: DataTypeNode[_, _]), _) => ()
case _ => throw new Exception("top level higher-order functions are not supported")
}
val (b, a, w) = rec(e2)
(b, a + 1, b => Expr(Lambda(w(b)), e.t))
case _ =>
if (!e.t.node.isInstanceOf[DataTypeNode[_, _]]) {
throw new Exception("expected body with data type")
}
(e, 0, b => b)
}
}

val (b, a, w) = rec(e)
(b, (write, List.tabulate(a)(i => i -> read).toMap), w)
}

object Cost extends CostFunction[Int] {
val ordering = implicitly

override def cost(egraph: EGraph, enode: ENode, t: TypeId, costs: EClassId => Int): Int = {
import rise.core.primitives._

val nodeCost = enode match {
// prefer avoiding mapSeq
case Primitive(mapSeq()) => 5
case _ => 1
}
enode.children().foldLeft(nodeCost) { case (acc, eclass) => acc + costs(eclass) }
}
}
}
1 change: 1 addition & 0 deletions float-safe-optimizer/examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.c
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/add3.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in =>
map(add(li32(3)))(in)
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/add3Seq.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun((n`.`i32) ->: (n`.`i32))(in =>
mapSeq(add(li32(3)))(in)
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/add3TypeError.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun((n`.`i32) ->: (n`.`f32))(in =>
map(add(li32(3)))(in)
))
4 changes: 4 additions & 0 deletions float-safe-optimizer/examples/addv.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
depFun((n: Nat) => depFun((m: Nat) => depFun((o: Nat) =>
fun(((n+o)`.`i32) ->: ((m+o)`.`i32) ->: (o`.`i32))((a, b) =>
zip(take(o)(a))(take(o)(b)) |> map(fun(x => fst(x) + snd(x)))
))))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/cos.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun(n`.`f64)(x =>
x |> mapSeq(foreignFun("cos", f64 ->: f64))
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/generate.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun(i32 ->: i32 ->: 2`.`i32)((x, y) =>
generate(fun(i => select(i =:= lidx(0, 2))(x)(y))) |> mapSeq(fun(x => x))
))
7 changes: 7 additions & 0 deletions float-safe-optimizer/examples/larr.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import rise.core.semantics._

def larrToMem(s: Seq[Data]) = larr(s) |> mapSeqUnroll(fun(x => x)) |> toMem

depFun((n: Nat) =>
larrToMem(Seq(IntData(1), IntData(2), IntData(3)))
)
5 changes: 5 additions & 0 deletions float-safe-optimizer/examples/length-rw.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def length() = reduceSeq(fun(acc => fun(x => acc + li32(1))))(li32(0))

depFun((n: Nat) => fun(n`.`i32)(x =>
x |> mapSeq(fun(y => y + length()(x)))
))
5 changes: 5 additions & 0 deletions float-safe-optimizer/examples/length.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def length() = impl { n: Nat => impl{ dt: DataType => fun(n`.`dt)(x => l(n)) } }

depFun((n: Nat) => fun(n`.`i32)(x =>
x |> mapSeq(fun(y => y + (cast(length()(x)) :: i32)))
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/makeArray.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun(i32 ->: i32 ->: 2`.`i32)((x, y) =>
makeArray(2)(x)(y) |> mapSeqUnroll(fun(x => x))
))
3 changes: 3 additions & 0 deletions float-safe-optimizer/examples/scalarOut.rise
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
depFun((n: Nat) => fun(n`.`f64)(x =>
x |> reduceSeq(add)(lf64(0.0))
))
2 changes: 2 additions & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.0-RC1")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.3.0")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.2.17")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.15.0")
addDependencyTreePlugin
Loading
Loading