Skip to content

Commit 3f33a6d

Browse files
committed
simplify a_simplify_lang
1 parent 9c10fe4 commit 3f33a6d

File tree

2 files changed

+61
-196
lines changed

2 files changed

+61
-196
lines changed

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 59 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ pub fn simplify_program(mut program: Program) -> Result<SimpleProgram, String> {
234234
program.functions.remove(&name);
235235
}
236236

237-
let mut mutable_loop_counter = MutableLoopTransformCounter::default();
237+
let mut mutable_loop_counter = Counter::new();
238238
transform_mutable_in_loops_in_program(&mut program, &mut mutable_loop_counter);
239239

240240
let mut new_functions = BTreeMap::new();
@@ -297,36 +297,42 @@ pub fn simplify_program(mut program: Program) -> Result<SimpleProgram, String> {
297297
})
298298
}
299299

300-
#[derive(Debug, Clone, Default)]
301-
pub struct VectorLenTracker {
302-
vectors: BTreeMap<Var, VectorLenValue>,
300+
#[derive(Debug, Clone, PartialEq, Eq)]
301+
pub enum TreeVec<S> {
302+
Scalar(S),
303+
Vector(Vec<TreeVec<S>>),
303304
}
304305

305-
#[derive(Debug, Clone, PartialEq, Eq)]
306-
pub enum VectorLenValue {
307-
Scalar,
308-
Vector(Vec<VectorLenValue>),
306+
pub type VectorLenValue = TreeVec<()>;
307+
pub type VectorValue = TreeVec<Var>;
308+
309+
#[derive(Debug, Clone, Default)]
310+
pub struct TreeVecTracker<S> {
311+
vectors: BTreeMap<Var, TreeVec<S>>,
309312
}
310313

311-
impl VectorLenTracker {
312-
fn register(&mut self, var: &Var, value: VectorLenValue) {
314+
pub type VectorLenTracker = TreeVecTracker<()>;
315+
type VectorTracker = TreeVecTracker<Var>;
316+
317+
impl<S> TreeVecTracker<S> {
318+
fn register(&mut self, var: &Var, value: TreeVec<S>) {
313319
self.vectors.insert(var.clone(), value);
314320
}
315321

316322
fn is_vector(&self, var: &Var) -> bool {
317323
self.vectors.contains_key(var)
318324
}
319325

320-
pub fn get(&self, var: &Var) -> Option<&VectorLenValue> {
326+
pub fn get(&self, var: &Var) -> Option<&TreeVec<S>> {
321327
self.vectors.get(var)
322328
}
323329

324-
fn get_mut(&mut self, var: &Var) -> Option<&mut VectorLenValue> {
330+
fn get_mut(&mut self, var: &Var) -> Option<&mut TreeVec<S>> {
325331
self.vectors.get_mut(var)
326332
}
327333
}
328334

329-
impl VectorLenValue {
335+
impl<S> TreeVec<S> {
330336
pub fn push(&mut self, elem: Self) {
331337
match self {
332338
Self::Vector(v) => v.push(elem),
@@ -366,31 +372,25 @@ impl VectorLenValue {
366372
}
367373
}
368374

369-
pub fn navigate(&self, idx: &[F]) -> Option<&Self> {
370-
idx.iter().try_fold(self, |v, &i| v.get(i.to_usize()))
375+
pub fn navigate(&self, idx: &[usize]) -> Option<&Self> {
376+
idx.iter().try_fold(self, |v, &i| v.get(i))
371377
}
372378

373-
pub fn navigate_mut(&mut self, idx: &[F]) -> Option<&mut Self> {
374-
idx.iter().try_fold(self, |v, &i| v.get_mut(i.to_usize()))
379+
pub fn navigate_mut(&mut self, idx: &[usize]) -> Option<&mut Self> {
380+
idx.iter().try_fold(self, |v, &i| v.get_mut(i))
375381
}
376382
}
377383

378384
fn build_vector_len_value(elements: &[VecLiteral]) -> VectorLenValue {
379-
let mut vec_elements = Vec::new();
380-
381-
for elem in elements {
382-
let elem_len_value = build_vector_len_value_from_element(elem);
383-
vec_elements.push(elem_len_value);
384-
}
385-
386-
VectorLenValue::Vector(vec_elements)
387-
}
388-
389-
fn build_vector_len_value_from_element(element: &VecLiteral) -> VectorLenValue {
390-
match element {
391-
VecLiteral::Vec(inner) => build_vector_len_value(inner),
392-
VecLiteral::Expr(_) => VectorLenValue::Scalar,
393-
}
385+
VectorLenValue::Vector(
386+
elements
387+
.iter()
388+
.map(|elem| match elem {
389+
VecLiteral::Vec(inner) => build_vector_len_value(inner),
390+
VecLiteral::Expr(_) => VectorLenValue::Scalar(()),
391+
})
392+
.collect(),
393+
)
394394
}
395395

396396
fn compile_time_transform_in_program(
@@ -578,10 +578,17 @@ fn compile_time_transform_in_lines(
578578
element,
579579
..
580580
} => {
581-
let Some(const_indices) = indices.iter().map(|idx| idx.as_scalar()).collect::<Option<Vec<_>>>() else {
581+
let Some(const_indices) = indices
582+
.iter()
583+
.map(|idx| idx.as_scalar().map(|f| f.to_usize()))
584+
.collect::<Option<Vec<_>>>()
585+
else {
582586
return Err("push with non-constant indices".to_string());
583587
};
584-
let new_element = build_vector_len_value_from_element(element);
588+
let new_element = match element {
589+
VecLiteral::Vec(inner) => build_vector_len_value(inner),
590+
VecLiteral::Expr(_) => VectorLenValue::Scalar(()),
591+
};
585592
let vector_value = vector_len_tracker
586593
.get_mut(vector)
587594
.ok_or_else(|| "pushing to undeclared vector".to_string())?;
@@ -603,7 +610,11 @@ fn compile_time_transform_in_lines(
603610
indices,
604611
location,
605612
} => {
606-
let Some(const_indices) = indices.iter().map(|idx| idx.as_scalar()).collect::<Option<Vec<_>>>() else {
613+
let Some(const_indices) = indices
614+
.iter()
615+
.map(|idx| idx.as_scalar().map(|f| f.to_usize()))
616+
.collect::<Option<Vec<_>>>()
617+
else {
607618
return Err(format!("line {}: pop with non-constant indices", location));
608619
};
609620
let vector_value = vector_len_tracker
@@ -1066,20 +1077,6 @@ fn substitute_const_vars_in_expr(expr: &mut Expression, const_var_exprs: &BTreeM
10661077
// }
10671078
// x = x_buff[size];
10681079

1069-
/// Counter for generating unique variable names in the mutable loop transformation
1070-
#[derive(Default)]
1071-
struct MutableLoopTransformCounter {
1072-
counter: usize,
1073-
}
1074-
1075-
impl MutableLoopTransformCounter {
1076-
fn next_suffix(&mut self) -> usize {
1077-
let c = self.counter;
1078-
self.counter += 1;
1079-
c
1080-
}
1081-
}
1082-
10831080
/// Finds mutable variables that are:
10841081
/// 1. Defined OUTSIDE this block (external)
10851082
/// 2. Re-assigned INSIDE this block
@@ -1141,7 +1138,7 @@ fn find_assigned_external_vars_helper(
11411138
}
11421139
}
11431140

1144-
fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut MutableLoopTransformCounter) {
1141+
fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut Counter) {
11451142
for func in program.functions.values_mut() {
11461143
transform_mutable_in_loops_in_lines(&mut func.body, &program.const_arrays, counter);
11471144
}
@@ -1150,7 +1147,7 @@ fn transform_mutable_in_loops_in_program(program: &mut Program, counter: &mut Mu
11501147
fn transform_mutable_in_loops_in_lines(
11511148
lines: &mut Vec<Line>,
11521149
const_arrays: &BTreeMap<String, ConstArrayValue>,
1153-
counter: &mut MutableLoopTransformCounter,
1150+
counter: &mut Counter,
11541151
) {
11551152
let mut i = 0;
11561153
while i < lines.len() {
@@ -1176,7 +1173,7 @@ fn transform_mutable_in_loops_in_lines(
11761173
continue;
11771174
}
11781175

1179-
let suffix = counter.next_suffix();
1176+
let suffix = counter.get_next();
11801177

11811178
// Generate the transformed code
11821179
let mut new_lines = Vec::new();
@@ -1267,7 +1264,13 @@ fn transform_mutable_in_loops_in_lines(
12671264
});
12681265

12691266
// Replace all references to var with body_name in the original body
1270-
replace_var_in_lines(body, var, body_name);
1267+
transform_vars_in_lines(body, &|v: &Var| {
1268+
if v == var {
1269+
VarTransform::Rename(body_name.clone())
1270+
} else {
1271+
VarTransform::Keep
1272+
}
1273+
});
12711274
}
12721275

12731276
// Add the original body (now modified to use body_vars)
@@ -1342,65 +1345,6 @@ fn transform_mutable_in_loops_in_lines(
13421345
}
13431346
}
13441347

1345-
/// Replaces all occurrences of a variable with another variable in a list of lines.
1346-
/// This is used to replace references to mutable variables with their body counterparts.
1347-
fn replace_var_in_lines(lines: &mut [Line], old_var: &Var, new_var: &Var) {
1348-
for line in lines {
1349-
match line {
1350-
Line::ForwardDeclaration { var, .. } => {
1351-
if var == old_var {
1352-
*var = new_var.clone();
1353-
}
1354-
}
1355-
Line::Statement { targets, .. } => {
1356-
for target in targets {
1357-
match target {
1358-
AssignmentTarget::Var { var, .. } => {
1359-
if var == old_var {
1360-
*var = new_var.clone();
1361-
}
1362-
}
1363-
AssignmentTarget::ArrayAccess { array, index } => {
1364-
if array == old_var {
1365-
*array = new_var.clone();
1366-
}
1367-
replace_var_in_expr(index, old_var, new_var);
1368-
}
1369-
}
1370-
}
1371-
}
1372-
_ => {}
1373-
}
1374-
for expr in line.expressions_mut() {
1375-
replace_var_in_expr(expr, old_var, new_var);
1376-
}
1377-
for block in line.nested_blocks_mut() {
1378-
replace_var_in_lines(block, old_var, new_var);
1379-
}
1380-
}
1381-
}
1382-
1383-
fn replace_var_in_expr(expr: &mut Expression, old_var: &Var, new_var: &Var) {
1384-
match expr {
1385-
Expression::Value(simple_expr) => {
1386-
if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = simple_expr
1387-
&& var == old_var
1388-
{
1389-
*var = new_var.clone();
1390-
}
1391-
}
1392-
Expression::ArrayAccess { array, .. } => {
1393-
if array == old_var {
1394-
*array = new_var.clone();
1395-
}
1396-
}
1397-
_ => {}
1398-
}
1399-
for inner_expr in expr.inner_exprs_mut() {
1400-
replace_var_in_expr(inner_expr, old_var, new_var);
1401-
}
1402-
}
1403-
14041348
fn check_function_always_returns(func: &SimpleFunction) -> Result<(), String> {
14051349
check_block_always_returns(&func.name, &func.instructions)
14061350
}
@@ -1728,8 +1672,7 @@ struct Counters {
17281672

17291673
impl Counters {
17301674
fn aux_var(&mut self) -> Var {
1731-
let var = format!("@aux_var_{}", self.aux_vars.get_next());
1732-
var
1675+
format!("@aux_var_{}", self.aux_vars.get_next())
17331676
}
17341677
}
17351678

@@ -1902,85 +1845,6 @@ impl MutableVarTracker {
19021845
}
19031846
}
19041847

1905-
/// Compile-time vector. Scalars hold variable names; Vectors hold nested values.
1906-
#[derive(Debug, Clone, PartialEq, Eq)]
1907-
pub enum VectorValue {
1908-
Scalar { var: Var },
1909-
Vector(Vec<VectorValue>),
1910-
}
1911-
1912-
impl VectorValue {
1913-
pub fn len(&self) -> usize {
1914-
match self {
1915-
Self::Vector(v) => v.len(),
1916-
_ => panic!("len on scalar"),
1917-
}
1918-
}
1919-
1920-
pub fn is_vector(&self) -> bool {
1921-
matches!(self, Self::Vector(_))
1922-
}
1923-
1924-
fn get(&self, i: usize) -> Option<&Self> {
1925-
match self {
1926-
Self::Vector(v) => v.get(i),
1927-
_ => None,
1928-
}
1929-
}
1930-
1931-
fn get_mut(&mut self, i: usize) -> Option<&mut Self> {
1932-
match self {
1933-
Self::Vector(v) => v.get_mut(i),
1934-
_ => None,
1935-
}
1936-
}
1937-
1938-
pub fn navigate(&self, idx: &[usize]) -> Option<&Self> {
1939-
idx.iter().try_fold(self, |v, &i| v.get(i))
1940-
}
1941-
1942-
pub fn navigate_mut(&mut self, idx: &[usize]) -> Option<&mut Self> {
1943-
idx.iter().try_fold(self, |v, &i| v.get_mut(i))
1944-
}
1945-
1946-
pub fn push(&mut self, elem: Self) {
1947-
match self {
1948-
Self::Vector(v) => v.push(elem),
1949-
_ => panic!("push on scalar"),
1950-
}
1951-
}
1952-
1953-
pub fn pop(&mut self) -> Option<Self> {
1954-
match self {
1955-
Self::Vector(v) => v.pop(),
1956-
_ => panic!("pop on scalar"),
1957-
}
1958-
}
1959-
}
1960-
1961-
#[derive(Debug, Clone, Default)]
1962-
struct VectorTracker {
1963-
vectors: BTreeMap<Var, VectorValue>,
1964-
}
1965-
1966-
impl VectorTracker {
1967-
fn register(&mut self, var: &Var, value: VectorValue) {
1968-
self.vectors.insert(var.clone(), value);
1969-
}
1970-
1971-
fn is_vector(&self, var: &Var) -> bool {
1972-
self.vectors.contains_key(var)
1973-
}
1974-
1975-
fn get(&self, var: &Var) -> Option<&VectorValue> {
1976-
self.vectors.get(var)
1977-
}
1978-
1979-
fn get_mut(&mut self, var: &Var) -> Option<&mut VectorValue> {
1980-
self.vectors.get_mut(var)
1981-
}
1982-
}
1983-
19841848
#[derive(Debug, Clone, Default)]
19851849
pub struct ConstMalloc {
19861850
counter: usize,
@@ -2038,7 +1902,7 @@ fn build_vector_value_from_element(
20381902
let aux_var = state.counters.aux_var();
20391903
let simplified_value = simplify_expr(ctx, state, const_malloc, expr, lines)?;
20401904
lines.push(SimpleLine::equality(aux_var.clone(), simplified_value));
2041-
Ok(VectorValue::Scalar { var: aux_var })
1905+
Ok(VectorValue::Scalar(aux_var))
20421906
}
20431907
}
20441908
}
@@ -3000,7 +2864,7 @@ fn simplify_expr(
30002864
.ok_or_else(|| format!("Vector index out of bounds: {:?}", const_indices))?;
30012865

30022866
match element {
3003-
VectorValue::Scalar { var } => {
2867+
VectorValue::Scalar(var) => {
30042868
// Return memory reference to this variable
30052869
return Ok(SimpleExpr::Memory(VarOrConstMallocAccess::Var(var.clone())));
30062870
}

crates/lean_compiler/src/lang.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,8 @@ impl Expression {
429429
return Some(F::from_usize(target.len()));
430430
}
431431
if let Some(arr) = vector_len.get(array) {
432-
let target = arr.navigate(&idx)?;
432+
let usize_idx: Vec<usize> = idx.iter().map(|f| f.to_usize()).collect();
433+
let target = arr.navigate(&usize_idx)?;
433434
return Some(F::from_usize(target.len()));
434435
}
435436
return None;

0 commit comments

Comments
 (0)