@@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where
2
2
3
3
import Data.HashMap.Strict qualified as HashMap
4
4
import Data.HashSet qualified as HashSet
5
- import Data.List.NonEmpty.Extra qualified as NonEmpty
6
5
import Data.Text qualified as T
7
6
import Data.Text qualified as Text
8
7
import Juvix.Compiler.Backend.Isabelle.Data.Result
@@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
95
94
mkExprCase c@ Case {.. } = case _caseValue of
96
95
ExprIden v ->
97
96
case _caseBranches of
98
- CaseBranch {.. } :| [] ->
97
+ CaseBranch {.. } :| _ ->
99
98
case _caseBranchPattern of
100
99
PatVar v' -> substVar v' v _caseBranchBody
101
100
_ -> ExprCase c
102
- _ -> ExprCase c
103
101
ExprTuple (Tuple (ExprIden v :| [] )) ->
104
102
case _caseBranches of
105
- CaseBranch {.. } :| [] ->
103
+ CaseBranch {.. } :| _ ->
106
104
case _caseBranchPattern of
107
105
PatTuple (Tuple (PatVar v' :| [] )) -> substVar v' v _caseBranchBody
108
106
_ -> ExprCase c
109
- _ -> ExprCase c
110
- _ -> ExprCase c
107
+ _ ->
108
+ case _caseBranches of
109
+ br@ CaseBranch {.. } :| _ ->
110
+ case _caseBranchPattern of
111
+ PatVar _ ->
112
+ ExprCase
113
+ Case
114
+ { _caseValue = _caseValue,
115
+ _caseBranches = br :| []
116
+ }
117
+ PatTuple (Tuple (PatVar _ :| [] )) ->
118
+ ExprCase
119
+ Case
120
+ { _caseValue = _caseValue,
121
+ _caseBranches = br :| []
122
+ }
123
+ _ -> ExprCase c
111
124
112
125
goMutualBlock :: Internal. MutualBlock -> [Statement ]
113
126
goMutualBlock Internal. MutualBlock {.. } =
@@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
243
256
: goClauses cls
244
257
Nested pats npats ->
245
258
let rhs = goExpression'' nset' nmap' _lambdaBody
246
- argnames' = fmap getPatternArgName _lambdaPatterns
259
+ argnames' = fmap getPatternArgName lambdaPats
247
260
vnames =
248
- fmap
249
- ( \ (idx :: Int , mname ) ->
250
- maybe
251
- ( defaultName
252
- (getLoc cl)
253
- ( disambiguate
254
- (nset' ^. nameSet)
255
- (" v_" <> show idx)
256
- )
257
- )
258
- (overNameText (disambiguate (nset' ^. nameSet)))
259
- mname
260
- )
261
- (NonEmpty. zip (nonEmpty' [0 .. ]) argnames')
261
+ nonEmpty' $
262
+ fmap
263
+ ( \ (idx :: Int , mname ) ->
264
+ maybe
265
+ ( defaultName
266
+ (getLoc cl)
267
+ ( disambiguate
268
+ (nset' ^. nameSet)
269
+ (" v_" <> show idx)
270
+ )
271
+ )
272
+ (overNameText (disambiguate (nset' ^. nameSet)))
273
+ mname
274
+ )
275
+ (zip [0 .. ] argnames')
262
276
nset'' = foldl' (flip (over nameSet . HashSet. insert . (^. namePretty))) nset' vnames
263
- remainingBranches = goLambdaClauses'' nset'' nmap' cls
277
+ remainingBranches = goLambdaClauses'' nset'' nmap' ( Just ty) cls
264
278
valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
265
279
patTuple = PatTuple (Tuple (nonEmpty' pats))
266
280
brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
@@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
275
289
}
276
290
]
277
291
where
278
- (npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
292
+ lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
293
+ (npats0, nset', nmap') = goPatternArgsTop lambdaPats
279
294
[] -> []
280
295
281
296
goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch ] -> Pattern -> NonEmpty (Expression , Nested Pattern ) -> NonEmpty CaseBranch
@@ -828,18 +843,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
828
843
| patsNum == 0 = goExpression (head _lambdaClauses ^. Internal. lambdaBody)
829
844
| otherwise = goLams vars
830
845
where
831
- patsNum =
832
- case _lambdaType of
833
- Just ty ->
834
- length
835
- . filterTypeArgs 0 ty
836
- . toList
837
- $ head _lambdaClauses ^. Internal. lambdaPatterns
838
- Nothing ->
839
- length
840
- . filter ((/= Internal. Implicit ) . (^. Internal. patternArgIsImplicit))
841
- . toList
842
- $ head _lambdaClauses ^. Internal. lambdaPatterns
846
+ patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal. lambdaPatterns
843
847
vars = map (\ i -> defaultName (getLoc lam) (" x" <> show i)) [0 .. patsNum - 1 ]
844
848
845
849
goLams :: [Name ] -> Sem r Expression
@@ -869,7 +873,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
869
873
Tuple
870
874
{ _tupleComponents = nonEmpty' vars'
871
875
}
872
- brs <- goLambdaClauses (toList _lambdaClauses)
876
+ brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
873
877
return $
874
878
mkExprCase
875
879
Case
@@ -926,17 +930,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
926
930
Internal. CaseBranchRhsExpression e -> goExpression e
927
931
Internal. CaseBranchRhsIf {} -> error " unsupported: side conditions"
928
932
929
- goLambdaClauses'' :: NameSet -> NameMap -> [Internal. LambdaClause ] -> [CaseBranch ]
930
- goLambdaClauses'' nset nmap cls =
931
- run $ runReader nset $ runReader nmap $ goLambdaClauses cls
932
-
933
- goLambdaClauses :: forall r . (Members '[Reader NameSet , Reader NameMap ] r ) => [Internal. LambdaClause ] -> Sem r [CaseBranch ]
934
- goLambdaClauses = \ case
933
+ filterLambdaPatternArgs :: Maybe Internal. Expression -> NonEmpty Internal. PatternArg -> [Internal. PatternArg ]
934
+ filterLambdaPatternArgs mty cls = case mty of
935
+ Just ty ->
936
+ filterTypeArgs 0 ty
937
+ . toList
938
+ $ cls
939
+ Nothing ->
940
+ filter ((/= Internal. Implicit ) . (^. Internal. patternArgIsImplicit))
941
+ . toList
942
+ $ cls
943
+
944
+ goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal. Expression -> [Internal. LambdaClause ] -> [CaseBranch ]
945
+ goLambdaClauses'' nset nmap mty cls =
946
+ run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls
947
+
948
+ goLambdaClauses :: forall r . (Members '[Reader NameSet , Reader NameMap ] r ) => Maybe Internal. Expression -> [Internal. LambdaClause ] -> Sem r [CaseBranch ]
949
+ goLambdaClauses mty = \ case
935
950
cl@ Internal. LambdaClause {.. } : cls -> do
936
- (npat, nset, nmap) <- case _lambdaPatterns of
937
- p :| [] -> goPatternArgCase p
951
+ let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
952
+ (npat, nset, nmap) <- case lambdaPats of
953
+ [p] -> goPatternArgCase p
938
954
_ -> do
939
- (npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
955
+ (npats, nset, nmap) <- goPatternArgsCase lambdaPats
940
956
let npat =
941
957
fmap
942
958
( \ pats ->
@@ -950,7 +966,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
950
966
case npat of
951
967
Nested pat [] -> do
952
968
body <- withLocalNames nset nmap $ goExpression _lambdaBody
953
- brs <- goLambdaClauses cls
969
+ brs <- goLambdaClauses mty cls
954
970
return $
955
971
CaseBranch
956
972
{ _caseBranchPattern = pat,
@@ -961,7 +977,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
961
977
let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) " v" )
962
978
nset' = over nameSet (HashSet. insert (vname ^. namePretty)) nset
963
979
rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
964
- remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
980
+ remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
965
981
let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
966
982
return
967
983
[ CaseBranch
@@ -1133,7 +1149,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
1133
1149
case HashMap. lookup name (infoTable ^. Internal. infoConstructors) of
1134
1150
Just ctrInfo
1135
1151
| ctrInfo ^. Internal. constructorInfoRecord ->
1136
- Just (indName, goRecordFields (getArgtys ctrInfo) args)
1152
+ case HashMap. lookup indName (infoTable ^. Internal. infoInductives) of
1153
+ Just indInfo
1154
+ | length (indInfo ^. Internal. inductiveInfoConstructors) == 1 ->
1155
+ Just (indName, goRecordFields (getArgtys ctrInfo) args)
1156
+ _ -> Nothing
1137
1157
where
1138
1158
indName = ctrInfo ^. Internal. constructorInfoInductive
1139
1159
_ -> Nothing
0 commit comments