Skip to content

Commit 03b824e

Browse files
committed
fix nested pattern matching
1 parent c3dbace commit 03b824e

File tree

3 files changed

+167
-48
lines changed

3 files changed

+167
-48
lines changed

src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs

+68-48
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where
22

33
import Data.HashMap.Strict qualified as HashMap
44
import Data.HashSet qualified as HashSet
5-
import Data.List.NonEmpty.Extra qualified as NonEmpty
65
import Data.Text qualified as T
76
import Data.Text qualified as Text
87
import Juvix.Compiler.Backend.Isabelle.Data.Result
@@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
9594
mkExprCase c@Case {..} = case _caseValue of
9695
ExprIden v ->
9796
case _caseBranches of
98-
CaseBranch {..} :| [] ->
97+
CaseBranch {..} :| _ ->
9998
case _caseBranchPattern of
10099
PatVar v' -> substVar v' v _caseBranchBody
101100
_ -> ExprCase c
102-
_ -> ExprCase c
103101
ExprTuple (Tuple (ExprIden v :| [])) ->
104102
case _caseBranches of
105-
CaseBranch {..} :| [] ->
103+
CaseBranch {..} :| _ ->
106104
case _caseBranchPattern of
107105
PatTuple (Tuple (PatVar v' :| [])) -> substVar v' v _caseBranchBody
108106
_ -> ExprCase c
109-
_ -> ExprCase c
110-
_ -> ExprCase c
107+
_ ->
108+
case _caseBranches of
109+
br@CaseBranch {..} :| _ ->
110+
case _caseBranchPattern of
111+
PatVar _ ->
112+
ExprCase
113+
Case
114+
{ _caseValue = _caseValue,
115+
_caseBranches = br :| []
116+
}
117+
PatTuple (Tuple (PatVar _ :| [])) ->
118+
ExprCase
119+
Case
120+
{ _caseValue = _caseValue,
121+
_caseBranches = br :| []
122+
}
123+
_ -> ExprCase c
111124

112125
goMutualBlock :: Internal.MutualBlock -> [Statement]
113126
goMutualBlock Internal.MutualBlock {..} =
@@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
243256
: goClauses cls
244257
Nested pats npats ->
245258
let rhs = goExpression'' nset' nmap' _lambdaBody
246-
argnames' = fmap getPatternArgName _lambdaPatterns
259+
argnames' = fmap getPatternArgName lambdaPats
247260
vnames =
248-
fmap
249-
( \(idx :: Int, mname) ->
250-
maybe
251-
( defaultName
252-
(getLoc cl)
253-
( disambiguate
254-
(nset' ^. nameSet)
255-
("v_" <> show idx)
256-
)
257-
)
258-
(overNameText (disambiguate (nset' ^. nameSet)))
259-
mname
260-
)
261-
(NonEmpty.zip (nonEmpty' [0 ..]) argnames')
261+
nonEmpty' $
262+
fmap
263+
( \(idx :: Int, mname) ->
264+
maybe
265+
( defaultName
266+
(getLoc cl)
267+
( disambiguate
268+
(nset' ^. nameSet)
269+
("v_" <> show idx)
270+
)
271+
)
272+
(overNameText (disambiguate (nset' ^. nameSet)))
273+
mname
274+
)
275+
(zip [0 ..] argnames')
262276
nset'' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset' vnames
263-
remainingBranches = goLambdaClauses'' nset'' nmap' cls
277+
remainingBranches = goLambdaClauses'' nset'' nmap' (Just ty) cls
264278
valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
265279
patTuple = PatTuple (Tuple (nonEmpty' pats))
266280
brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
@@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
275289
}
276290
]
277291
where
278-
(npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
292+
lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
293+
(npats0, nset', nmap') = goPatternArgsTop lambdaPats
279294
[] -> []
280295

281296
goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch] -> Pattern -> NonEmpty (Expression, Nested Pattern) -> NonEmpty CaseBranch
@@ -828,18 +843,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
828843
| patsNum == 0 = goExpression (head _lambdaClauses ^. Internal.lambdaBody)
829844
| otherwise = goLams vars
830845
where
831-
patsNum =
832-
case _lambdaType of
833-
Just ty ->
834-
length
835-
. filterTypeArgs 0 ty
836-
. toList
837-
$ head _lambdaClauses ^. Internal.lambdaPatterns
838-
Nothing ->
839-
length
840-
. filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
841-
. toList
842-
$ head _lambdaClauses ^. Internal.lambdaPatterns
846+
patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal.lambdaPatterns
843847
vars = map (\i -> defaultName (getLoc lam) ("x" <> show i)) [0 .. patsNum - 1]
844848

845849
goLams :: [Name] -> Sem r Expression
@@ -869,7 +873,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
869873
Tuple
870874
{ _tupleComponents = nonEmpty' vars'
871875
}
872-
brs <- goLambdaClauses (toList _lambdaClauses)
876+
brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
873877
return $
874878
mkExprCase
875879
Case
@@ -926,17 +930,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
926930
Internal.CaseBranchRhsExpression e -> goExpression e
927931
Internal.CaseBranchRhsIf {} -> error "unsupported: side conditions"
928932

929-
goLambdaClauses'' :: NameSet -> NameMap -> [Internal.LambdaClause] -> [CaseBranch]
930-
goLambdaClauses'' nset nmap cls =
931-
run $ runReader nset $ runReader nmap $ goLambdaClauses cls
932-
933-
goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.LambdaClause] -> Sem r [CaseBranch]
934-
goLambdaClauses = \case
933+
filterLambdaPatternArgs :: Maybe Internal.Expression -> NonEmpty Internal.PatternArg -> [Internal.PatternArg]
934+
filterLambdaPatternArgs mty cls = case mty of
935+
Just ty ->
936+
filterTypeArgs 0 ty
937+
. toList
938+
$ cls
939+
Nothing ->
940+
filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
941+
. toList
942+
$ cls
943+
944+
goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal.Expression -> [Internal.LambdaClause] -> [CaseBranch]
945+
goLambdaClauses'' nset nmap mty cls =
946+
run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls
947+
948+
goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Maybe Internal.Expression -> [Internal.LambdaClause] -> Sem r [CaseBranch]
949+
goLambdaClauses mty = \case
935950
cl@Internal.LambdaClause {..} : cls -> do
936-
(npat, nset, nmap) <- case _lambdaPatterns of
937-
p :| [] -> goPatternArgCase p
951+
let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
952+
(npat, nset, nmap) <- case lambdaPats of
953+
[p] -> goPatternArgCase p
938954
_ -> do
939-
(npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
955+
(npats, nset, nmap) <- goPatternArgsCase lambdaPats
940956
let npat =
941957
fmap
942958
( \pats ->
@@ -950,7 +966,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
950966
case npat of
951967
Nested pat [] -> do
952968
body <- withLocalNames nset nmap $ goExpression _lambdaBody
953-
brs <- goLambdaClauses cls
969+
brs <- goLambdaClauses mty cls
954970
return $
955971
CaseBranch
956972
{ _caseBranchPattern = pat,
@@ -961,7 +977,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
961977
let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) "v")
962978
nset' = over nameSet (HashSet.insert (vname ^. namePretty)) nset
963979
rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
964-
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
980+
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
965981
let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
966982
return
967983
[ CaseBranch
@@ -1133,7 +1149,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
11331149
case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of
11341150
Just ctrInfo
11351151
| ctrInfo ^. Internal.constructorInfoRecord ->
1136-
Just (indName, goRecordFields (getArgtys ctrInfo) args)
1152+
case HashMap.lookup indName (infoTable ^. Internal.infoInductives) of
1153+
Just indInfo
1154+
| length (indInfo ^. Internal.inductiveInfoConstructors) == 1 ->
1155+
Just (indName, goRecordFields (getArgtys ctrInfo) args)
1156+
_ -> Nothing
11371157
where
11381158
indName = ctrInfo ^. Internal.constructorInfoInductive
11391159
_ -> Nothing

tests/positive/Isabelle/Program.juvix

+41
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,44 @@ funR4 : R -> R
193193
bf (b1 b2 : Bool) : Bool := not (b1 && b2);
194194

195195
nf (n1 n2 : Int) : Bool := n1 - n2 >= n1 || n2 <= n1 + n2;
196+
197+
-- Nested record patterns
198+
199+
type MessagePacket (MessageType : Type) : Type := mkMessagePacket {
200+
target : Nat;
201+
mailbox : Maybe Nat;
202+
message : MessageType;
203+
};
204+
205+
type EnvelopedMessage (MessageType : Type) : Type :=
206+
mkEnvelopedMessage {
207+
sender : Maybe Nat;
208+
packet : MessagePacket MessageType;
209+
};
210+
211+
type Timer (HandleType : Type): Type := mkTimer {
212+
time : Nat;
213+
handle : HandleType;
214+
};
215+
216+
type Trigger (MessageType : Type) (HandleType : Type) :=
217+
| MessageArrived { envelope : EnvelopedMessage MessageType; }
218+
| Elapsed { timers : List (Timer HandleType) };
219+
220+
getMessageFromTrigger : {M H : Type} -> Trigger M H -> Maybe M
221+
| (MessageArrived@{
222+
envelope := (mkEnvelopedMessage@{
223+
packet := (mkMessagePacket@{
224+
message := m })})})
225+
:= just m
226+
| _ := nothing;
227+
228+
229+
getMessageFromTrigger' {M H} (t : Trigger M H) : Maybe M :=
230+
case t of
231+
| (MessageArrived@{
232+
envelope := (mkEnvelopedMessage@{
233+
packet := (mkMessagePacket@{
234+
message := m })})})
235+
:= just m
236+
| _ := nothing;

tests/positive/Isabelle/isabelle/Program.thy

+58
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,62 @@ fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
241241
fun nf :: "int \<Rightarrow> int \<Rightarrow> bool" where
242242
"nf n1 n2 = (n1 - n2 \<ge> n1 \<or> n2 \<le> n1 + n2)"
243243

244+
(* Nested record patterns *)
245+
record 'MessageType MessagePacket =
246+
target :: nat
247+
mailbox :: "nat option"
248+
message :: 'MessageType
249+
250+
record 'MessageType EnvelopedMessage =
251+
sender :: "nat option"
252+
packet :: "'MessageType MessagePacket"
253+
254+
record 'HandleType Timer =
255+
time :: nat
256+
handle :: 'HandleType
257+
258+
datatype ('MessageType, 'HandleType) Trigger
259+
= MessageArrived "'MessageType EnvelopedMessage" |
260+
Elapsed "('HandleType Timer) list"
261+
262+
fun target :: "'MessageType MessagePacket \<Rightarrow> nat" where
263+
"target (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
264+
target'"
265+
266+
fun mailbox :: "'MessageType MessagePacket \<Rightarrow> nat option" where
267+
"mailbox (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
268+
mailbox'"
269+
270+
fun message :: "'MessageType MessagePacket \<Rightarrow> 'MessageType" where
271+
"message (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
272+
message'"
273+
274+
fun sender :: "'MessageType EnvelopedMessage \<Rightarrow> nat option" where
275+
"sender (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = sender'"
276+
277+
fun packet :: "'MessageType EnvelopedMessage \<Rightarrow> 'MessageType MessagePacket" where
278+
"packet (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = packet'"
279+
280+
fun time :: "'HandleType Timer \<Rightarrow> nat" where
281+
"time (| Timer.time = time', Timer.handle = handle' |) = time'"
282+
283+
fun handle :: "'HandleType Timer \<Rightarrow> 'HandleType" where
284+
"handle (| Timer.time = time', Timer.handle = handle' |) = handle'"
285+
286+
fun getMessageFromTrigger :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
287+
"getMessageFromTrigger v_0 =
288+
(case (v_0) of
289+
(MessageArrived v') \<Rightarrow>
290+
(case (EnvelopedMessage.packet v') of
291+
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
292+
v'1 \<Rightarrow> None)"
293+
294+
fun getMessageFromTrigger' :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
295+
"getMessageFromTrigger' t =
296+
(case t of
297+
(MessageArrived v') \<Rightarrow>
298+
(case (EnvelopedMessage.packet v') of
299+
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
300+
v'2 \<Rightarrow> None)"
301+
244302
end

0 commit comments

Comments
 (0)