Skip to content

User guided specialisation stage #252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: hkmc2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
322b68b
Miscellaneous Fixes (#280)
CAG2Mark Feb 13, 2025
740dacb
Ported specialisation keyword to HEAD
Oli-Ar Feb 12, 2025
1796d7d
Added logic to pass specialisation flags through elaborator
Oli-Ar Feb 13, 2025
af9ac51
Brought back specialiser logic
Oli-Ar Feb 13, 2025
b91aa15
Pushed majority of logic for the simple sub algorithm
Oli-Ar Feb 24, 2025
423fafe
Pulled class info out from SimpleType to new class
Oli-Ar Mar 6, 2025
7189d2c
Fixed type inference for classes and some types of functions
Oli-Ar Mar 6, 2025
3211cec
Fixed relationship
Oli-Ar Mar 7, 2025
138ecca
Code cleanup and minor fixes
Oli-Ar Mar 7, 2025
9fa8b52
Added type inference for statements
Oli-Ar Mar 7, 2025
7ae6090
Delayed processing of TermDefs fixing some class of inference bugs
Oli-Ar Mar 7, 2025
89ca14c
Specialisation inference
Oli-Ar Mar 7, 2025
f6bd7bb
Generating specialised functions
Oli-Ar Mar 7, 2025
4aa2131
Arguments to specialised parameters are now extracted
Oli-Ar Mar 8, 2025
d0ae2e6
Cleaned out tests
Oli-Ar Mar 8, 2025
73b350a
Specialiser type pattern matching added
Oli-Ar Mar 14, 2025
7f4b2a5
Specialiser type pattern matching added
Oli-Ar Mar 14, 2025
2b235bb
Renaming
Oli-Ar Mar 15, 2025
952e08f
Handling specialisation inside specialised functions better
Oli-Ar Mar 15, 2025
e57dfcd
Improved specificity of specialiser
Oli-Ar Mar 18, 2025
ed681da
Added typeof function check support
Oli-Ar Mar 27, 2025
0a795a7
Small features & Bug Fixes
Oli-Ar Mar 27, 2025
c08a954
Re-added removed test
Oli-Ar Mar 29, 2025
ae4dfb1
WIP: Changed specilisation to work on specialised parameter vectors
Oli-Ar Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class MLsCompiler(preludeFile: os.Path, mkOutput: ((Str => Unit) => Unit) => Uni

// TODO adapt logic
val etl = new TraceLogger{override def doTrace: Bool = false}
val stl = new TraceLogger{override def doTrace: Bool = false}
val ltl = new TraceLogger{override def doTrace: Bool = false}


Expand Down Expand Up @@ -79,14 +80,16 @@ class MLsCompiler(preludeFile: os.Path, mkOutput: ((Str => Unit) => Unit) => Uni
newCtx.nest(N).givenIn:
val elab = Elaborator(etl, wd, newCtx)
val parsed = mainParse.resultBlk
val (blk0, _) = elab.importFrom(parsed)
val blk = blk0.copy(stats = semantics.Import(State.runtimeSymbol, runtimeFile.toString) :: blk0.stats)
val (blk0, ctx) = elab.importFrom(parsed)
val blk: semantics.Term.Blk = blk0.copy(stats = semantics.Import(State.runtimeSymbol, runtimeFile.toString) :: blk0.stats)
val typ = new semantics.Specialiser(ctx, stl)
val spBlk = typ.topLevel(blk)
val low = ltl.givenIn:
new codegen.Lowering()
with codegen.LoweringSelSanityChecks
val jsb = ltl.givenIn:
codegen.js.JSBuilder()
val le = low.program(blk)
val le = low.program(spBlk)
val baseScp: utils.Scope =
utils.Scope.empty
val nestedScp = baseScp.nest
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,14 @@ sealed abstract class Block extends Product with AutoLocated:
case AssignDynField(_, _, _, rhs, rest) => rhs.subBlocks ::: rest :: Nil
case Define(d, rest) => d.subBlocks ::: rest :: Nil
case HandleBlock(_, _, par, args, _, handlers, body, rest) => par.subBlocks ++ args.flatMap(_.subBlocks) ++ handlers.map(_.body) :+ body :+ rest
case Label(_, body, rest) => body :: rest :: Nil

// TODO rm Lam from values and thus the need for these cases
case Return(r, _) => r.subBlocks
case HandleBlockReturn(r) => r.subBlocks
case Throw(r) => r.subBlocks

case _: Return | _: Throw | _: Label | _: Break | _: Continue | _: End | _: HandleBlockReturn => Nil
case _: Return | _: Throw | _: Break | _: Continue | _: End | _: HandleBlockReturn => Nil

// Moves definitions in a block to the top. Only scans the top-level definitions of the block;
// i.e, definitions inside other definitions are not moved out. Definitions inside `match`/`if`
Expand Down
14 changes: 7 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ class BlockTransformer(subst: SymbolSubst):
(cls2 is cls) && (hdr2 is hdr) && (bod2 is bod) && (rst2 is rst)
then b else HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
val lhs2 = applyPath(lhs)
val fld2 = applyPath(fld)
val rhs2 = applyResult(rhs)
val rest2 = applyBlock(rest)
if (lhs2 is lhs) && (fld2 is fld) && (rhs2 is rhs) && (rest2 is rest)
then b
else AssignDynField(lhs2, fld2, arrayIdx, rhs2, rest2)
applyResult2(rhs): rhs2 =>
val lhs2 = applyPath(lhs)
val fld2 = applyPath(fld)
val rest2 = applyBlock(rest)
if (lhs2 is lhs) && (fld2 is fld) && (rhs2 is rhs) && (rest2 is rest)
then b
else AssignDynField(lhs2, fld2, arrayIdx, rhs2, rest2)


def applyResult2(r: Result)(k: Result => Block): Block = k(applyResult(r))
Expand Down
14 changes: 10 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
.assignFieldN(state.res, tailIdent, state.res.tail.next)
.ret(state.res))
private val functionHandlerCtx = funcLikeHandlerCtx(N)
private val topLevelCtx = HandlerCtx(true, true, N, _ => rtThrowMsg("Unhandled effects"))
private def ctorCtx(ctorThis: Path) = funcLikeHandlerCtx(S(ctorThis))
private def handlerCtx(using HandlerCtx): HandlerCtx = summon
private val predefPath: Path = State.globalThisSymbol.asPath.selN(Tree.Ident("Predef"))
Expand Down Expand Up @@ -267,6 +268,9 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
case blk @ AssignField(lhs, nme, rhs, rest) =>
val PartRet(head, parts) = go(rest)
PartRet(AssignField(lhs, nme, rhs, head)(blk.symbol), parts)
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
val PartRet(head, parts) = go(rest)
PartRet(AssignDynField(lhs, fld, arrayIdx, rhs, head), parts)
case Return(_, _) => PartRet(blk, Nil)
// ignored cases
case TryBlock(sub, finallyDo, rest) => ??? // ignore
Expand Down Expand Up @@ -336,7 +340,7 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
override def applyLam(lam: Value.Lam): Value.Lam = Value.Lam(lam.params, translateBlock(lam.body, functionHandlerCtx))
override def applyDefn(defn: Defn): Defn = defn match
case f: FunDefn => translateFun(f)
case c: ClsLikeDefn => translateCls(c)
case c: ClsLikeDefn => translateCls(c, handlerCtx.isTopLevel)
case _: ValDefn => super.applyDefn(defn)
transformer.applyBlock(b)

Expand All @@ -356,9 +360,11 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
private def translateFun(f: FunDefn): FunDefn =
FunDefn(f.owner, f.sym, f.params, translateBlock(f.body, functionHandlerCtx))

private def translateCls(cls: ClsLikeDefn): ClsLikeDefn =
private def translateCls(cls: ClsLikeDefn, isTopLevel: Bool): ClsLikeDefn =
val curCtorCtx = if isTopLevel && (cls.k is syntax.Mod) then topLevelCtx else
ctorCtx(cls.sym.asClsLike.getOrElse(wat("asClsLike", cls.sym)).asPath)
cls.copy(methods = cls.methods.map(translateFun),
ctor = translateBlock(cls.ctor, ctorCtx(cls.sym.asClsLike.getOrElse(wat("asClsLike", cls.sym)).asPath)))
ctor = translateBlock(cls.ctor, curCtorCtx))

// Handle block becomes a FunDefn and CallPlaceholder
private def translateHandleBlock(h: HandleBlock): Block =
Expand Down Expand Up @@ -548,5 +554,5 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
transform.applyBlock(b)

def translateTopLevel(b: Block): Block =
translateBlock(b, HandlerCtx(true, true, N, _ => rtThrowMsg("Unhandled effects")))
translateBlock(b, topLevelCtx)

Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class StackSafeTransform(depthLimit: Int)(using State):
override def applyFunDefn(fun: FunDefn): FunDefn = rewriteFn(fun)

override def applyDefn(defn: Defn): Defn = defn match
case defn: ClsLikeDefn => rewriteCls(defn)
case defn: ClsLikeDefn => rewriteCls(defn, isTopLevel)
case _: FunDefn | _: ValDefn => super.applyDefn(defn)

override def applyBlock(b: Block): Block = b match
Expand Down Expand Up @@ -143,12 +143,13 @@ class StackSafeTransform(depthLimit: Int)(using State):
walker.applyBlock(b)
trivial

def rewriteCls(defn: ClsLikeDefn): ClsLikeDefn =
def rewriteCls(defn: ClsLikeDefn, isTopLevel: Bool): ClsLikeDefn =
val ClsLikeDefn(owner, isym, sym, k, paramsOpt,
parentPath, methods, privateFields, publicFields, preCtor, ctor) = defn
ClsLikeDefn(
owner, isym, sym, k, paramsOpt, parentPath, methods.map(rewriteFn), privateFields,
publicFields, rewriteBlk(preCtor), rewriteBlk(ctor)
publicFields, rewriteBlk(preCtor),
if isTopLevel && (defn.k is syntax.Mod) then transformTopLevel(ctor) else rewriteBlk(ctor)
)

def rewriteBlk(blk: Block) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
case Elaborator.ctx.builtins.Num => doc"typeof $sd === 'number'"
case Elaborator.ctx.builtins.Bool => doc"typeof $sd === 'boolean'"
case Elaborator.ctx.builtins.Int => doc"globalThis.Number.isInteger($sd)"
case Elaborator.ctx.builtins.Function => doc"typeof $sd === 'function'"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this as I had a test where I specialise a parameter which I pass a lambda to, so when I generate the pattern matching for the application I needed to be able to match the lambda, but I'd just like to check that this is okay to do since I'm not very familiar with the codegen side of mlscript.

case _ => doc"$sd instanceof ${result(pth)}"
case Case.Tup(len, inf) => doc"globalThis.Array.isArray($sd) && $sd.length ${if inf then ">=" else "==="} ${len}"
val h = doc" # if (${ cond(hd._1) }) ${ braced(returningTerm(hd._2, endSemi = false)) }"
Expand Down
2 changes: 2 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,8 @@ extends Importer:
ps
case TypeDef(Pat, inner, N, N) =>
param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(pat = true))))
case Modified(Keyword.`spec`, _, inner) =>
param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(spec = true))))
case _ =>
t.asParam(inUsing).map: (isSpd, p, t) =>
isSpd -> Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term(_)))
Expand Down
Loading