Skip to content

Commit 42bd0a8

Browse files
committed
fix nested pattern matching
1 parent d855023 commit 42bd0a8

File tree

3 files changed

+167
-48
lines changed

3 files changed

+167
-48
lines changed

src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs

+68-48
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where
22

33
import Data.HashMap.Strict qualified as HashMap
44
import Data.HashSet qualified as HashSet
5-
import Data.List.NonEmpty.Extra qualified as NonEmpty
65
import Data.Text qualified as T
76
import Data.Text qualified as Text
87
import Juvix.Compiler.Backend.Isabelle.Data.Result
@@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
9594
mkExprCase c@Case {..} = case _caseValue of
9695
ExprIden v ->
9796
case _caseBranches of
98-
CaseBranch {..} :| [] ->
97+
CaseBranch {..} :| _ ->
9998
case _caseBranchPattern of
10099
PatVar v' -> substVar v' v _caseBranchBody
101100
_ -> ExprCase c
102-
_ -> ExprCase c
103101
ExprTuple (Tuple (ExprIden v :| [])) ->
104102
case _caseBranches of
105-
CaseBranch {..} :| [] ->
103+
CaseBranch {..} :| _ ->
106104
case _caseBranchPattern of
107105
PatTuple (Tuple (PatVar v' :| [])) -> substVar v' v _caseBranchBody
108106
_ -> ExprCase c
109-
_ -> ExprCase c
110-
_ -> ExprCase c
107+
_ ->
108+
case _caseBranches of
109+
br@CaseBranch {..} :| _ ->
110+
case _caseBranchPattern of
111+
PatVar _ ->
112+
ExprCase
113+
Case
114+
{ _caseValue = _caseValue,
115+
_caseBranches = br :| []
116+
}
117+
PatTuple (Tuple (PatVar _ :| [])) ->
118+
ExprCase
119+
Case
120+
{ _caseValue = _caseValue,
121+
_caseBranches = br :| []
122+
}
123+
_ -> ExprCase c
111124

112125
goMutualBlock :: Internal.MutualBlock -> [Statement]
113126
goMutualBlock Internal.MutualBlock {..} =
@@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
243256
: goClauses cls
244257
Nested pats npats ->
245258
let rhs = goExpression'' nset' nmap' _lambdaBody
246-
argnames' = fmap getPatternArgName _lambdaPatterns
259+
argnames' = fmap getPatternArgName lambdaPats
247260
vnames =
248-
fmap
249-
( \(idx :: Int, mname) ->
250-
maybe
251-
( defaultName
252-
(getLoc cl)
253-
( disambiguate
254-
(nset' ^. nameSet)
255-
("v_" <> show idx)
256-
)
257-
)
258-
(overNameText (disambiguate (nset' ^. nameSet)))
259-
mname
260-
)
261-
(NonEmpty.zip (nonEmpty' [0 ..]) argnames')
261+
nonEmpty' $
262+
fmap
263+
( \(idx :: Int, mname) ->
264+
maybe
265+
( defaultName
266+
(getLoc cl)
267+
( disambiguate
268+
(nset' ^. nameSet)
269+
("v_" <> show idx)
270+
)
271+
)
272+
(overNameText (disambiguate (nset' ^. nameSet)))
273+
mname
274+
)
275+
(zip [0 ..] argnames')
262276
nset'' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset' vnames
263-
remainingBranches = goLambdaClauses'' nset'' nmap' cls
277+
remainingBranches = goLambdaClauses'' nset'' nmap' (Just ty) cls
264278
valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
265279
patTuple = PatTuple (Tuple (nonEmpty' pats))
266280
brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
@@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
275289
}
276290
]
277291
where
278-
(npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
292+
lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
293+
(npats0, nset', nmap') = goPatternArgsTop lambdaPats
279294
[] -> []
280295

281296
goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch] -> Pattern -> NonEmpty (Expression, Nested Pattern) -> NonEmpty CaseBranch
@@ -828,18 +843,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
828843
| patsNum == 0 = goExpression (head _lambdaClauses ^. Internal.lambdaBody)
829844
| otherwise = goLams vars
830845
where
831-
patsNum =
832-
case _lambdaType of
833-
Just ty ->
834-
length
835-
. filterTypeArgs 0 ty
836-
. toList
837-
$ head _lambdaClauses ^. Internal.lambdaPatterns
838-
Nothing ->
839-
length
840-
. filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
841-
. toList
842-
$ head _lambdaClauses ^. Internal.lambdaPatterns
846+
patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal.lambdaPatterns
843847
vars = map (\i -> defaultName (getLoc lam) ("x" <> show i)) [0 .. patsNum - 1]
844848

845849
goLams :: [Name] -> Sem r Expression
@@ -869,7 +873,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
869873
Tuple
870874
{ _tupleComponents = nonEmpty' vars'
871875
}
872-
brs <- goLambdaClauses (toList _lambdaClauses)
876+
brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
873877
return $
874878
mkExprCase
875879
Case
@@ -926,17 +930,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
926930
Internal.CaseBranchRhsExpression e -> goExpression e
927931
Internal.CaseBranchRhsIf {} -> error "unsupported: side conditions"
928932

929-
goLambdaClauses'' :: NameSet -> NameMap -> [Internal.LambdaClause] -> [CaseBranch]
930-
goLambdaClauses'' nset nmap cls =
931-
run $ runReader nset $ runReader nmap $ goLambdaClauses cls
932-
933-
goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.LambdaClause] -> Sem r [CaseBranch]
934-
goLambdaClauses = \case
933+
filterLambdaPatternArgs :: Maybe Internal.Expression -> NonEmpty Internal.PatternArg -> [Internal.PatternArg]
934+
filterLambdaPatternArgs mty cls = case mty of
935+
Just ty ->
936+
filterTypeArgs 0 ty
937+
. toList
938+
$ cls
939+
Nothing ->
940+
filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
941+
. toList
942+
$ cls
943+
944+
goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal.Expression -> [Internal.LambdaClause] -> [CaseBranch]
945+
goLambdaClauses'' nset nmap mty cls =
946+
run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls
947+
948+
goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Maybe Internal.Expression -> [Internal.LambdaClause] -> Sem r [CaseBranch]
949+
goLambdaClauses mty = \case
935950
cl@Internal.LambdaClause {..} : cls -> do
936-
(npat, nset, nmap) <- case _lambdaPatterns of
937-
p :| [] -> goPatternArgCase p
951+
let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
952+
(npat, nset, nmap) <- case lambdaPats of
953+
[p] -> goPatternArgCase p
938954
_ -> do
939-
(npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
955+
(npats, nset, nmap) <- goPatternArgsCase lambdaPats
940956
let npat =
941957
fmap
942958
( \pats ->
@@ -950,7 +966,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
950966
case npat of
951967
Nested pat [] -> do
952968
body <- withLocalNames nset nmap $ goExpression _lambdaBody
953-
brs <- goLambdaClauses cls
969+
brs <- goLambdaClauses mty cls
954970
return $
955971
CaseBranch
956972
{ _caseBranchPattern = pat,
@@ -961,7 +977,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
961977
let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) "v")
962978
nset' = over nameSet (HashSet.insert (vname ^. namePretty)) nset
963979
rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
964-
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
980+
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
965981
let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
966982
return
967983
[ CaseBranch
@@ -1133,7 +1149,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
11331149
case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of
11341150
Just ctrInfo
11351151
| ctrInfo ^. Internal.constructorInfoRecord ->
1136-
Just (indName, goRecordFields (getArgtys ctrInfo) args)
1152+
case HashMap.lookup indName (infoTable ^. Internal.infoInductives) of
1153+
Just indInfo
1154+
| length (indInfo ^. Internal.inductiveInfoConstructors) == 1 ->
1155+
Just (indName, goRecordFields (getArgtys ctrInfo) args)
1156+
_ -> Nothing
11371157
where
11381158
indName = ctrInfo ^. Internal.constructorInfoInductive
11391159
_ -> Nothing

tests/positive/Isabelle/Program.juvix

+41
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,44 @@ funR4 : R -> R
154154
bf (b1 b2 : Bool) : Bool := not (b1 && b2);
155155

156156
nf (n1 n2 : Int) : Bool := n1 - n2 >= n1 || n2 <= n1 + n2;
157+
158+
-- Nested record patterns
159+
160+
type MessagePacket (MessageType : Type) : Type := mkMessagePacket {
161+
target : Nat;
162+
mailbox : Maybe Nat;
163+
message : MessageType;
164+
};
165+
166+
type EnvelopedMessage (MessageType : Type) : Type :=
167+
mkEnvelopedMessage {
168+
sender : Maybe Nat;
169+
packet : MessagePacket MessageType;
170+
};
171+
172+
type Timer (HandleType : Type): Type := mkTimer {
173+
time : Nat;
174+
handle : HandleType;
175+
};
176+
177+
type Trigger (MessageType : Type) (HandleType : Type) :=
178+
| MessageArrived { envelope : EnvelopedMessage MessageType; }
179+
| Elapsed { timers : List (Timer HandleType) };
180+
181+
getMessageFromTrigger : {M H : Type} -> Trigger M H -> Maybe M
182+
| (MessageArrived@{
183+
envelope := (mkEnvelopedMessage@{
184+
packet := (mkMessagePacket@{
185+
message := m })})})
186+
:= just m
187+
| _ := nothing;
188+
189+
190+
getMessageFromTrigger' {M H} (t : Trigger M H) : Maybe M :=
191+
case t of
192+
| (MessageArrived@{
193+
envelope := (mkEnvelopedMessage@{
194+
packet := (mkMessagePacket@{
195+
message := m })})})
196+
:= just m
197+
| _ := nothing;

tests/positive/Isabelle/isabelle/Program.thy

+58
Original file line numberDiff line numberDiff line change
@@ -240,4 +240,62 @@ fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
240240
fun nf :: "int \<Rightarrow> int \<Rightarrow> bool" where
241241
"nf n1 n2 = (n1 - n2 \<ge> n1 \<or> n2 \<le> n1 + n2)"
242242

243+
(* Nested record patterns *)
244+
record 'MessageType MessagePacket =
245+
target :: nat
246+
mailbox :: "nat option"
247+
message :: 'MessageType
248+
249+
record 'MessageType EnvelopedMessage =
250+
sender :: "nat option"
251+
packet :: "'MessageType MessagePacket"
252+
253+
record 'HandleType Timer =
254+
time :: nat
255+
handle :: 'HandleType
256+
257+
datatype ('MessageType, 'HandleType) Trigger
258+
= MessageArrived "'MessageType EnvelopedMessage" |
259+
Elapsed "('HandleType Timer) list"
260+
261+
fun target :: "'MessageType MessagePacket \<Rightarrow> nat" where
262+
"target (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
263+
target'"
264+
265+
fun mailbox :: "'MessageType MessagePacket \<Rightarrow> nat option" where
266+
"mailbox (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
267+
mailbox'"
268+
269+
fun message :: "'MessageType MessagePacket \<Rightarrow> 'MessageType" where
270+
"message (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
271+
message'"
272+
273+
fun sender :: "'MessageType EnvelopedMessage \<Rightarrow> nat option" where
274+
"sender (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = sender'"
275+
276+
fun packet :: "'MessageType EnvelopedMessage \<Rightarrow> 'MessageType MessagePacket" where
277+
"packet (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) = packet'"
278+
279+
fun time :: "'HandleType Timer \<Rightarrow> nat" where
280+
"time (| Timer.time = time', Timer.handle = handle' |) = time'"
281+
282+
fun handle :: "'HandleType Timer \<Rightarrow> 'HandleType" where
283+
"handle (| Timer.time = time', Timer.handle = handle' |) = handle'"
284+
285+
fun getMessageFromTrigger :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
286+
"getMessageFromTrigger v_0 =
287+
(case (v_0) of
288+
(MessageArrived v') \<Rightarrow>
289+
(case (EnvelopedMessage.packet v') of
290+
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
291+
v'1 \<Rightarrow> None)"
292+
293+
fun getMessageFromTrigger' :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
294+
"getMessageFromTrigger' t =
295+
(case t of
296+
(MessageArrived v') \<Rightarrow>
297+
(case (EnvelopedMessage.packet v') of
298+
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
299+
v'2 \<Rightarrow> None)"
300+
243301
end

0 commit comments

Comments
 (0)