Skip to content

Commit 6d2f3fe

Browse files
committed
Case compilation algorithm in SML
This is from https://julesjacobs.com/notes/patternmatching/patternmatching.pdf. It compiles arbitrarily complicated/nested patterns into ones where constructors are only applied to variable arguments. It keeps underscore patterns, which can be compiled away when you have a complete list of all constructors to hand. This will guide the actual algorithm that gets written in HOL.
1 parent ce0adec commit 6d2f3fe

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed

compiler/parsing/cases-algo.ML

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
structure cases =
2+
struct
3+
4+
(* Source patterns.
     pv name         - a pattern variable
     pc (cnm, args)  - constructor cnm applied to the argument patterns args
     pUS             - the underscore (wildcard) pattern *)
datatype pat
  = pv of string
  | pc of string * pat list
  | pUS
5+
6+
(* Expressions produced by the case compiler.
     eLet (x, v, body) - built by push_var from an equation (v, pv x);
                         wraps body
     eBase n           - a leaf result, identified by an integer
     eNoMatch          - returned by get_firstbranch when no rows remain
     eMatch (v, pes)   - a test on variable v with (pattern, expression)
                         branches *)
datatype exp
  = eLet of string * string * exp
  | eBase of int
  | eNoMatch
  | eMatch of string * (pat * exp) list
9+
10+
11+
(* A match problem is a list of rows.  Each row pairs a list of equations
   (test-variable name, pattern) with the expression for that row. *)
type problem = ((string * pat) list * exp) list
12+
13+
(* Shorthand for writing the example patterns: binary "Add" and "Mul",
   unary "Succ" and nullary "Zero" constructor patterns. *)
local
  fun app cnm args = pc (cnm, args)
in
  fun A (lhs, rhs) = app "Add" [lhs, rhs]
  fun M (lhs, rhs) = app "Mul" [lhs, rhs]
  fun S arg = app "Succ" [arg]
  val Z = app "Zero" []
end
17+
18+
(* Example problem: seven rows, each testing the single variable "a"
   against one pattern and yielding eBase 1 .. eBase 7 in row order. *)
val jjeg1 : problem =
  let
    (* one row: test "a" against pattern p, result eBase n *)
    fun row (p, n) = ([("a", p)], eBase n)
  in
    map row
      [(A (Z, Z), 1),
       (M (Z, pv "X"), 2),
       (A (S (pv "X"), pv "Y"), 3),
       (M (pv "X", Z), 4),
       (M (A (pv "X", pv "Y"), pv "Z"), 5),
       (A (pv "X", Z), 6),
       (pv "X", 7)]
  end
26+
27+
(* push_var (eqns, rhs)
   Discharges every equation of a row that imposes no test:
     - (v, pv x) is removed and rhs is wrapped as eLet (x, v, rhs);
       later equations wrap outside earlier ones;
     - (v, pUS) is removed outright, since a wildcard neither constrains
       v nor binds anything.
   Constructor equations are kept, in their original order.

   Dropping wildcard equations here matters: otherwise a row whose only
   remaining equations are wildcards never reduces to an empty equation
   list, and the later stages (heuristic, lift) have no constructor to
   branch on for it. *)
fun push_var (eqns, rhs) =
  let
    fun step ((v, pv pnm), (kept, body)) = (kept, eLet (pnm, v, body))
      | step ((_, pUS), acc) = acc
      | step (eqn, (kept, body)) = (eqn :: kept, body)
    (* kept is accumulated in reverse; restored below *)
    val (kept, body) = List.foldl step ([], rhs) eqns
  in
    (List.rev kept, body)
  end
36+
37+
(* Apply push_var to every row of the problem, keeping the rows in order. *)
fun push_vars (p : problem) : problem = map push_var p
38+
39+
(* Number of argument patterns of a constructor pattern.
   Raises Fail for any pattern that is not a constructor application. *)
fun pat_arity p =
  case p of
      pc (_, args) => List.length args
    | _ => raise Fail "pat_arity on p. variable"
41+
(* Constructor name of a constructor pattern.
   Raises Fail for any pattern that is not a constructor application. *)
fun pat_con p =
  case p of
      pc (cnm, _) => cnm
    | _ => raise Fail "pat_con on p. variable"
43+
44+
(* pluck P xs
   Finds the first element of xs satisfying P.  Returns SOME (x, rest),
   where rest is xs with that one occurrence removed and the order of the
   other elements preserved, or NONE if no element satisfies P. *)
fun pluck P xs =
  let
    (* seen holds the already-rejected prefix, most recent first *)
    fun go _ [] = NONE
      | go seen (x :: rest) =
          if P x then SOME (x, List.revAppend (seen, rest))
          else go (x :: seen) rest
  in
    go [] xs
  end
47+
48+
49+
(* lift testvar cnm vars p
   Splits problem p on the test "testvar is built with constructor cnm",
   returning (yes, no); row order is preserved in both results.

   For each row, the first equation on testvar decides:
     - no such equation: the row is kept unchanged in both results;
     - constructor cnm: the equation is replaced by equations pairing vars
       with the constructor's argument patterns, and the row goes to yes;
     - a different constructor: the row goes, unchanged, to no;
     - wildcard: the equation is trivially true, so it is dropped and the
       row goes to both results (previously this raised Match);
     - pattern variable: raises Fail, as these should already have been
       removed by push_var.

   NOTE: ListPair.zip truncates to the shorter list, so vars is expected
   to have the same length as the constructor's argument list. *)
fun lift testvar cnm vars (p : problem) : problem * problem =
  let
    fun lift1 (row as (eqns, rhs), (yes, no)) =
      case pluck (fn (tv, _) => tv = testvar) eqns of
          NONE => (row :: yes, row :: no)
        | SOME ((_, pc (cnm', args')), other_tests : (string * pat) list) =>
            if cnm' = cnm then
              ((ListPair.zip (vars, args') @ other_tests, rhs) :: yes, no)
            else (yes, row :: no)
        | SOME ((_, pUS), other_tests) =>
            let val row' = (other_tests, rhs)
            in (row' :: yes, row' :: no) end
        | SOME ((_, pv _), _) => raise Fail "lift1: found pat-var binding"
    (* both accumulators are built in reverse *)
    val (yes, no) = List.foldl lift1 ([], []) p
  in
    (List.rev yes, List.rev no)
  end
63+
64+
(* bumpany k e m
   Increments the count stored under key k in m.  If k is absent, it is
   inserted with count 1 and payload e; if present, the existing payload
   is kept and only the count changes. *)
fun bumpany k e m =
  let
    val entry =
      case Binarymap.peek (m, k) of
          SOME (c, e0) => (c + 1, e0)
        | NONE => (1, e)
  in
    Binarymap.insert (m, k, entry)
  end
68+
(* bumpex k m
   Increments the count stored under key k only if k is already present
   in m; otherwise returns m unchanged. *)
fun bumpex k m =
  getOpt (Option.map (fn (c, e) => Binarymap.insert (m, k, (c + 1, e)))
                     (Binarymap.peek (m, k)),
          m)
72+
73+
(* maxcount M
   Folds over the (count, payload) entries of M and returns SOME of the
   entry with the largest count, or NONE for an empty map.  An entry only
   replaces the current best when its count is strictly greater, so among
   equal counts the one visited first by Binarymap.foldl is kept. *)
fun maxcount M =
  Binarymap.foldl
    (fn (_, cand as (n, _), best) =>
        case best of
            NONE => SOME cand
          | SOME (n0, _) => if n > n0 then SOME cand else best)
    NONE M
80+
81+
(* heuristic eqns rest
   Chooses which constructor equation of the first row to branch on.
   Every (variable, constructor) pair occurring in eqns is seeded into a
   map; each occurrence of an already-seeded pair in the rows of rest then
   bumps its count.  The result is the equation from eqns whose pair has
   the highest count (see maxcount for tie-breaking).
   Raises Option if eqns contains no constructor equation. *)
fun heuristic eqns rest =
  let
    val empty = Binarymap.mkDict (pair_compare (String.compare, String.compare))
    (* count pairs of the first row, remembering the equation itself *)
    fun seed (eqn, counts) =
      case eqn of
          (vnm, pc (cnm, _)) => bumpany (vnm, cnm) eqn counts
        | _ => counts
    (* count further occurrences, but only of pairs already seeded *)
    fun tally (eqn, counts) =
      case eqn of
          (vnm, pc (cnm, _)) => bumpex (vnm, cnm) counts
        | _ => counts
    val seeded = List.foldl seed empty eqns
    val tallied =
      List.foldl
        (fn ((row_eqns, _), counts) => List.foldl tally counts row_eqns)
        seeded rest
  in
    #2 (valOf (maxcount tallied))
  end
93+
94+
95+
(* get_firstbranch p0
   Compiles a match problem into a tree of single-constructor tests.
   After push_vars:
     - no rows: eNoMatch;
     - first row has no equations left: its right-hand side;
     - otherwise: heuristic picks an equation (tvar, constructor pattern)
       from the first row, fresh argument names tvar0, tvar1, ... are
       generated, lift splits the problem on that constructor, and the
       result is an eMatch on tvar with a constructor branch and an
       underscore branch, each compiled recursively. *)
fun get_firstbranch (p0 : problem) =
  case push_vars p0 of
      [] => eNoMatch
    | ([], rhs) :: _ => rhs
    | (p as ((eqns, _) :: rest)) =>
        let
          val (tvar, pat) = heuristic eqns rest
          val newvars =
            List.tabulate (pat_arity pat, fn i => tvar ^ Int.toString i)
          val cnm = pat_con pat
          val (yes, no) = lift tvar cnm newvars p
        in
          eMatch (tvar, [(pc (cnm, map pv newvars), get_firstbranch yes),
                         (pUS, get_firstbranch no)])
        end
114+
115+
(* updlast xs rep
   Replaces the last element of xs by the elements of rep.  For an empty
   xs the result is just rep. *)
fun updlast xs rep =
  case xs of
      [] => rep
    | _ => List.take (xs, List.length xs - 1) @ rep
118+
119+
(* merge_dumbUS e
   Flattens chains of tests on the same variable.  Sub-expressions are
   simplified first; then, if the last branch of an eMatch on v is an
   underscore branch whose body is itself an eMatch on v, that branch is
   replaced by the inner match's branches.  eLet bodies are simplified
   recursively; all other expressions are returned unchanged.

   The last branch is inspected by pattern-matching on the reversed list,
   so an eMatch with an empty branch list is returned as is instead of
   raising from a partial "last". *)
fun merge_dumbUS e =
  case e of
      eMatch (testv1, pes) =>
        let
          val pes' = map (fn (p, body) => (p, merge_dumbUS body)) pes
        in
          case List.rev pes' of
              (pUS, eMatch (testv2, uspes)) :: _ =>
                if testv1 = testv2 then
                  eMatch (testv1, updlast pes' uspes)
                else eMatch (testv1, pes')
            | _ => eMatch (testv1, pes')
        end
    | eLet (v1, v2, body) => eLet (v1, v2, merge_dumbUS body)
    | _ => e
133+
134+
(* Example problem: five rows, each testing the single variable "a"
   against an "Add" pattern and yielding eBase 1 .. eBase 5 in row order. *)
val jjeg2 : problem =
  let
    (* one row: test "a" against pattern p, result eBase n *)
    fun row (p, n) = ([("a", p)], eBase n)
  in
    map row
      [(A (A (pv "X", pv "Y"), Z), 1),
       (A (M (pv "X", pv "Y"), Z), 2),
       (A (pv "X", M (pv "Y", pv "Z")), 3),
       (A (pv "X", A (pv "Y", pv "Z")), 4),
       (A (pv "X", Z), 5)]
  end
140+
141+
(* Worked example: the tree get_firstbranch builds for jjeg2, then passed
   through merge_dumbUS.  NOTE(review): `$` is not defined in this file;
   presumably it is infix function application from the surrounding
   environment - confirm. *)
val sol2 = merge_dumbUS $ get_firstbranch jjeg2
142+
143+
(* uniq_pfx p slist
   Extends p with "%" characters until it is no longer a prefix of any
   string in slist, and returns the extended string.  Only the strings
   that clashed with the current candidate are re-examined each round. *)
fun uniq_pfx p slist =
  let
    val clashes = List.filter (String.isPrefix p) slist
  in
    if List.null clashes then p else uniq_pfx (p ^ "%") clashes
  end
147+
148+
end

0 commit comments

Comments
 (0)