@@ -234,7 +234,7 @@ pub fn simplify_program(mut program: Program) -> Result<SimpleProgram, String> {
234234 program. functions . remove ( & name) ;
235235 }
236236
237- let mut mutable_loop_counter = MutableLoopTransformCounter :: default ( ) ;
237+ let mut mutable_loop_counter = Counter :: new ( ) ;
238238 transform_mutable_in_loops_in_program ( & mut program, & mut mutable_loop_counter) ;
239239
240240 let mut new_functions = BTreeMap :: new ( ) ;
@@ -297,36 +297,42 @@ pub fn simplify_program(mut program: Program) -> Result<SimpleProgram, String> {
297297 } )
298298}
299299
300- #[ derive( Debug , Clone , Default ) ]
301- pub struct VectorLenTracker {
302- vectors : BTreeMap < Var , VectorLenValue > ,
300+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
301+ pub enum TreeVec < S > {
302+ Scalar ( S ) ,
303+ Vector ( Vec < TreeVec < S > > ) ,
303304}
304305
305- #[ derive( Debug , Clone , PartialEq , Eq ) ]
306- pub enum VectorLenValue {
307- Scalar ,
308- Vector ( Vec < VectorLenValue > ) ,
306+ pub type VectorLenValue = TreeVec < ( ) > ;
307+ pub type VectorValue = TreeVec < Var > ;
308+
309+ #[ derive( Debug , Clone , Default ) ]
310+ pub struct TreeVecTracker < S > {
311+ vectors : BTreeMap < Var , TreeVec < S > > ,
309312}
310313
311- impl VectorLenTracker {
312- fn register ( & mut self , var : & Var , value : VectorLenValue ) {
314+ pub type VectorLenTracker = TreeVecTracker < ( ) > ;
315+ type VectorTracker = TreeVecTracker < Var > ;
316+
317+ impl < S > TreeVecTracker < S > {
318+ fn register ( & mut self , var : & Var , value : TreeVec < S > ) {
313319 self . vectors . insert ( var. clone ( ) , value) ;
314320 }
315321
316322 fn is_vector ( & self , var : & Var ) -> bool {
317323 self . vectors . contains_key ( var)
318324 }
319325
320- pub fn get ( & self , var : & Var ) -> Option < & VectorLenValue > {
326+ pub fn get ( & self , var : & Var ) -> Option < & TreeVec < S > > {
321327 self . vectors . get ( var)
322328 }
323329
324- fn get_mut ( & mut self , var : & Var ) -> Option < & mut VectorLenValue > {
330+ fn get_mut ( & mut self , var : & Var ) -> Option < & mut TreeVec < S > > {
325331 self . vectors . get_mut ( var)
326332 }
327333}
328334
329- impl VectorLenValue {
335+ impl < S > TreeVec < S > {
330336 pub fn push ( & mut self , elem : Self ) {
331337 match self {
332338 Self :: Vector ( v) => v. push ( elem) ,
@@ -366,31 +372,25 @@ impl VectorLenValue {
366372 }
367373 }
368374
369- pub fn navigate ( & self , idx : & [ F ] ) -> Option < & Self > {
370- idx. iter ( ) . try_fold ( self , |v, & i| v. get ( i. to_usize ( ) ) )
375+ pub fn navigate ( & self , idx : & [ usize ] ) -> Option < & Self > {
376+ idx. iter ( ) . try_fold ( self , |v, & i| v. get ( i) )
371377 }
372378
373- pub fn navigate_mut ( & mut self , idx : & [ F ] ) -> Option < & mut Self > {
374- idx. iter ( ) . try_fold ( self , |v, & i| v. get_mut ( i. to_usize ( ) ) )
379+ pub fn navigate_mut ( & mut self , idx : & [ usize ] ) -> Option < & mut Self > {
380+ idx. iter ( ) . try_fold ( self , |v, & i| v. get_mut ( i) )
375381 }
376382}
377383
378384fn build_vector_len_value ( elements : & [ VecLiteral ] ) -> VectorLenValue {
379- let mut vec_elements = Vec :: new ( ) ;
380-
381- for elem in elements {
382- let elem_len_value = build_vector_len_value_from_element ( elem) ;
383- vec_elements. push ( elem_len_value) ;
384- }
385-
386- VectorLenValue :: Vector ( vec_elements)
387- }
388-
389- fn build_vector_len_value_from_element ( element : & VecLiteral ) -> VectorLenValue {
390- match element {
391- VecLiteral :: Vec ( inner) => build_vector_len_value ( inner) ,
392- VecLiteral :: Expr ( _) => VectorLenValue :: Scalar ,
393- }
385+ VectorLenValue :: Vector (
386+ elements
387+ . iter ( )
388+ . map ( |elem| match elem {
389+ VecLiteral :: Vec ( inner) => build_vector_len_value ( inner) ,
390+ VecLiteral :: Expr ( _) => VectorLenValue :: Scalar ( ( ) ) ,
391+ } )
392+ . collect ( ) ,
393+ )
394394}
395395
396396fn compile_time_transform_in_program (
@@ -578,10 +578,17 @@ fn compile_time_transform_in_lines(
578578 element,
579579 ..
580580 } => {
581- let Some ( const_indices) = indices. iter ( ) . map ( |idx| idx. as_scalar ( ) ) . collect :: < Option < Vec < _ > > > ( ) else {
581+ let Some ( const_indices) = indices
582+ . iter ( )
583+ . map ( |idx| idx. as_scalar ( ) . map ( |f| f. to_usize ( ) ) )
584+ . collect :: < Option < Vec < _ > > > ( )
585+ else {
582586 return Err ( "push with non-constant indices" . to_string ( ) ) ;
583587 } ;
584- let new_element = build_vector_len_value_from_element ( element) ;
588+ let new_element = match element {
589+ VecLiteral :: Vec ( inner) => build_vector_len_value ( inner) ,
590+ VecLiteral :: Expr ( _) => VectorLenValue :: Scalar ( ( ) ) ,
591+ } ;
585592 let vector_value = vector_len_tracker
586593 . get_mut ( vector)
587594 . ok_or_else ( || "pushing to undeclared vector" . to_string ( ) ) ?;
@@ -603,7 +610,11 @@ fn compile_time_transform_in_lines(
603610 indices,
604611 location,
605612 } => {
606- let Some ( const_indices) = indices. iter ( ) . map ( |idx| idx. as_scalar ( ) ) . collect :: < Option < Vec < _ > > > ( ) else {
613+ let Some ( const_indices) = indices
614+ . iter ( )
615+ . map ( |idx| idx. as_scalar ( ) . map ( |f| f. to_usize ( ) ) )
616+ . collect :: < Option < Vec < _ > > > ( )
617+ else {
607618 return Err ( format ! ( "line {}: pop with non-constant indices" , location) ) ;
608619 } ;
609620 let vector_value = vector_len_tracker
@@ -1066,20 +1077,6 @@ fn substitute_const_vars_in_expr(expr: &mut Expression, const_var_exprs: &BTreeM
10661077// }
10671078// x = x_buff[size];
10681079
1069- /// Counter for generating unique variable names in the mutable loop transformation
1070- #[ derive( Default ) ]
1071- struct MutableLoopTransformCounter {
1072- counter : usize ,
1073- }
1074-
1075- impl MutableLoopTransformCounter {
1076- fn next_suffix ( & mut self ) -> usize {
1077- let c = self . counter ;
1078- self . counter += 1 ;
1079- c
1080- }
1081- }
1082-
10831080/// Finds mutable variables that are:
10841081/// 1. Defined OUTSIDE this block (external)
10851082/// 2. Re-assigned INSIDE this block
@@ -1141,7 +1138,7 @@ fn find_assigned_external_vars_helper(
11411138 }
11421139}
11431140
1144- fn transform_mutable_in_loops_in_program ( program : & mut Program , counter : & mut MutableLoopTransformCounter ) {
1141+ fn transform_mutable_in_loops_in_program ( program : & mut Program , counter : & mut Counter ) {
11451142 for func in program. functions . values_mut ( ) {
11461143 transform_mutable_in_loops_in_lines ( & mut func. body , & program. const_arrays , counter) ;
11471144 }
@@ -1150,7 +1147,7 @@ fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut Mu
11501147fn transform_mutable_in_loops_in_lines (
11511148 lines : & mut Vec < Line > ,
11521149 const_arrays : & BTreeMap < String , ConstArrayValue > ,
1153- counter : & mut MutableLoopTransformCounter ,
1150+ counter : & mut Counter ,
11541151) {
11551152 let mut i = 0 ;
11561153 while i < lines. len ( ) {
@@ -1176,7 +1173,7 @@ fn transform_mutable_in_loops_in_lines(
11761173 continue ;
11771174 }
11781175
1179- let suffix = counter. next_suffix ( ) ;
1176+ let suffix = counter. get_next ( ) ;
11801177
11811178 // Generate the transformed code
11821179 let mut new_lines = Vec :: new ( ) ;
@@ -1267,7 +1264,13 @@ fn transform_mutable_in_loops_in_lines(
12671264 } ) ;
12681265
12691266 // Replace all references to var with body_name in the original body
1270- replace_var_in_lines ( body, var, body_name) ;
1267+ transform_vars_in_lines ( body, & |v : & Var | {
1268+ if v == var {
1269+ VarTransform :: Rename ( body_name. clone ( ) )
1270+ } else {
1271+ VarTransform :: Keep
1272+ }
1273+ } ) ;
12711274 }
12721275
12731276 // Add the original body (now modified to use body_vars)
@@ -1342,65 +1345,6 @@ fn transform_mutable_in_loops_in_lines(
13421345 }
13431346}
13441347
1345- /// Replaces all occurrences of a variable with another variable in a list of lines.
1346- /// This is used to replace references to mutable variables with their body counterparts.
1347- fn replace_var_in_lines ( lines : & mut [ Line ] , old_var : & Var , new_var : & Var ) {
1348- for line in lines {
1349- match line {
1350- Line :: ForwardDeclaration { var, .. } => {
1351- if var == old_var {
1352- * var = new_var. clone ( ) ;
1353- }
1354- }
1355- Line :: Statement { targets, .. } => {
1356- for target in targets {
1357- match target {
1358- AssignmentTarget :: Var { var, .. } => {
1359- if var == old_var {
1360- * var = new_var. clone ( ) ;
1361- }
1362- }
1363- AssignmentTarget :: ArrayAccess { array, index } => {
1364- if array == old_var {
1365- * array = new_var. clone ( ) ;
1366- }
1367- replace_var_in_expr ( index, old_var, new_var) ;
1368- }
1369- }
1370- }
1371- }
1372- _ => { }
1373- }
1374- for expr in line. expressions_mut ( ) {
1375- replace_var_in_expr ( expr, old_var, new_var) ;
1376- }
1377- for block in line. nested_blocks_mut ( ) {
1378- replace_var_in_lines ( block, old_var, new_var) ;
1379- }
1380- }
1381- }
1382-
1383- fn replace_var_in_expr ( expr : & mut Expression , old_var : & Var , new_var : & Var ) {
1384- match expr {
1385- Expression :: Value ( simple_expr) => {
1386- if let SimpleExpr :: Memory ( VarOrConstMallocAccess :: Var ( var) ) = simple_expr
1387- && var == old_var
1388- {
1389- * var = new_var. clone ( ) ;
1390- }
1391- }
1392- Expression :: ArrayAccess { array, .. } => {
1393- if array == old_var {
1394- * array = new_var. clone ( ) ;
1395- }
1396- }
1397- _ => { }
1398- }
1399- for inner_expr in expr. inner_exprs_mut ( ) {
1400- replace_var_in_expr ( inner_expr, old_var, new_var) ;
1401- }
1402- }
1403-
14041348fn check_function_always_returns ( func : & SimpleFunction ) -> Result < ( ) , String > {
14051349 check_block_always_returns ( & func. name , & func. instructions )
14061350}
@@ -1728,8 +1672,7 @@ struct Counters {
17281672
17291673impl Counters {
17301674 fn aux_var ( & mut self ) -> Var {
1731- let var = format ! ( "@aux_var_{}" , self . aux_vars. get_next( ) ) ;
1732- var
1675+ format ! ( "@aux_var_{}" , self . aux_vars. get_next( ) )
17331676 }
17341677}
17351678
@@ -1902,85 +1845,6 @@ impl MutableVarTracker {
19021845 }
19031846}
19041847
1905- /// Compile-time vector. Scalars hold variable names; Vectors hold nested values.
1906- #[ derive( Debug , Clone , PartialEq , Eq ) ]
1907- pub enum VectorValue {
1908- Scalar { var : Var } ,
1909- Vector ( Vec < VectorValue > ) ,
1910- }
1911-
1912- impl VectorValue {
1913- pub fn len ( & self ) -> usize {
1914- match self {
1915- Self :: Vector ( v) => v. len ( ) ,
1916- _ => panic ! ( "len on scalar" ) ,
1917- }
1918- }
1919-
1920- pub fn is_vector ( & self ) -> bool {
1921- matches ! ( self , Self :: Vector ( _) )
1922- }
1923-
1924- fn get ( & self , i : usize ) -> Option < & Self > {
1925- match self {
1926- Self :: Vector ( v) => v. get ( i) ,
1927- _ => None ,
1928- }
1929- }
1930-
1931- fn get_mut ( & mut self , i : usize ) -> Option < & mut Self > {
1932- match self {
1933- Self :: Vector ( v) => v. get_mut ( i) ,
1934- _ => None ,
1935- }
1936- }
1937-
1938- pub fn navigate ( & self , idx : & [ usize ] ) -> Option < & Self > {
1939- idx. iter ( ) . try_fold ( self , |v, & i| v. get ( i) )
1940- }
1941-
1942- pub fn navigate_mut ( & mut self , idx : & [ usize ] ) -> Option < & mut Self > {
1943- idx. iter ( ) . try_fold ( self , |v, & i| v. get_mut ( i) )
1944- }
1945-
1946- pub fn push ( & mut self , elem : Self ) {
1947- match self {
1948- Self :: Vector ( v) => v. push ( elem) ,
1949- _ => panic ! ( "push on scalar" ) ,
1950- }
1951- }
1952-
1953- pub fn pop ( & mut self ) -> Option < Self > {
1954- match self {
1955- Self :: Vector ( v) => v. pop ( ) ,
1956- _ => panic ! ( "pop on scalar" ) ,
1957- }
1958- }
1959- }
1960-
1961- #[ derive( Debug , Clone , Default ) ]
1962- struct VectorTracker {
1963- vectors : BTreeMap < Var , VectorValue > ,
1964- }
1965-
1966- impl VectorTracker {
1967- fn register ( & mut self , var : & Var , value : VectorValue ) {
1968- self . vectors . insert ( var. clone ( ) , value) ;
1969- }
1970-
1971- fn is_vector ( & self , var : & Var ) -> bool {
1972- self . vectors . contains_key ( var)
1973- }
1974-
1975- fn get ( & self , var : & Var ) -> Option < & VectorValue > {
1976- self . vectors . get ( var)
1977- }
1978-
1979- fn get_mut ( & mut self , var : & Var ) -> Option < & mut VectorValue > {
1980- self . vectors . get_mut ( var)
1981- }
1982- }
1983-
19841848#[ derive( Debug , Clone , Default ) ]
19851849pub struct ConstMalloc {
19861850 counter : usize ,
@@ -2038,7 +1902,7 @@ fn build_vector_value_from_element(
20381902 let aux_var = state. counters . aux_var ( ) ;
20391903 let simplified_value = simplify_expr ( ctx, state, const_malloc, expr, lines) ?;
20401904 lines. push ( SimpleLine :: equality ( aux_var. clone ( ) , simplified_value) ) ;
2041- Ok ( VectorValue :: Scalar { var : aux_var } )
1905+ Ok ( VectorValue :: Scalar ( aux_var) )
20421906 }
20431907 }
20441908}
@@ -3000,7 +2864,7 @@ fn simplify_expr(
30002864 . ok_or_else ( || format ! ( "Vector index out of bounds: {:?}" , const_indices) ) ?;
30012865
30022866 match element {
3003- VectorValue :: Scalar { var } => {
2867+ VectorValue :: Scalar ( var) => {
30042868 // Return memory reference to this variable
30052869 return Ok ( SimpleExpr :: Memory ( VarOrConstMallocAccess :: Var ( var. clone ( ) ) ) ) ;
30062870 }
0 commit comments