diff --git a/vm/src/core/dead_code.rs b/vm/src/core/dead_code.rs index ca73a58b20..6af6fc5849 100644 --- a/vm/src/core/dead_code.rs +++ b/vm/src/core/dead_code.rs @@ -10,22 +10,60 @@ use base::{ use crate::core::{ self, - optimize::{walk_expr, walk_expr_alloc, SameLifetime, Visitor}, - Allocator, CExpr, Expr, LetBinding, + optimize::{walk_expr, walk_expr_alloc, DifferentLifetime, SameLifetime, Visitor}, + Allocator, CExpr, Expr, LetBinding, Named, Pattern, }; +fn is_pure_simple(expr: CExpr) -> bool { + pub struct SimplePure(bool); + + impl<'l, 'expr> Visitor<'l, 'expr> for SimplePure { + type Producer = DifferentLifetime<'l, 'expr>; + + fn visit_expr(&mut self, expr: CExpr<'expr>) -> Option> { + if !self.0 { + return None; + } + match *expr { + Expr::Call(..) => { + self.0 = false; + None + } + Expr::Let(ref bind, expr) => { + match bind.expr { + // Creating a group of closures is always pure (though calling them may not be) + Named::Recursive(_) => (), + Named::Expr(expr) => { + self.visit_expr(expr); + } + } + self.visit_expr(expr) + } + _ => walk_expr_alloc(self, expr), + } + } + fn detach_allocator(&self) -> Option<&'l Allocator<'l>> { + None + } + } + + let mut visitor = SimplePure(true); + visitor.visit_expr(expr); + visitor.0 +} + pub fn dead_code_elimination<'a>(allocator: &'a Allocator<'a>, expr: CExpr<'a>) -> CExpr<'a> { - struct FreeVars<'a> { + struct DeadCodeEliminator<'a> { allocator: &'a Allocator<'a>, used_bindings: FnvSet<&'a SymbolRef>, } - impl FreeVars<'_> { + impl DeadCodeEliminator<'_> { fn is_used(&self, s: &Symbol) -> bool { self.used_bindings.contains(&**s) } } - impl<'e> Visitor<'e, 'e> for FreeVars<'e> { + impl<'e> Visitor<'e, 'e> for DeadCodeEliminator<'e> { type Producer = SameLifetime<'e>; fn visit_expr(&mut self, expr: CExpr<'e>) -> Option> { @@ -69,6 +107,26 @@ pub fn dead_code_elimination<'a>(allocator: &'a Allocator<'a>, expr: CExpr<'a>) &*self.allocator.arena.alloc(Expr::Let(bind, body)) }) } + + Expr::Match(scrutinee, alts) if alts.len() == 1 => match &alts[0].pattern { + Pattern::Record(fields) => { + if !is_pure_simple(scrutinee) + || fields + .iter() + .map(|(x, y)| y.as_ref().unwrap_or(&x.name)) + .any(|field_bind| self.is_used(&field_bind)) + { + walk_expr_alloc(self, expr) + } else { + Some( + self.visit_expr(alts[0].expr) + .unwrap_or_else(|| alts[0].expr), + ) + } + } + _ => walk_expr_alloc(self, expr), + }, + _ => walk_expr_alloc(self, expr), } } @@ -77,7 +135,7 @@ pub fn dead_code_elimination<'a>(allocator: &'a Allocator<'a>, expr: CExpr<'a>) } } - let mut free_vars = FreeVars { + let mut free_vars = DeadCodeEliminator { allocator, used_bindings: DepGraph::default().used_bindings(expr), }; @@ -224,4 +282,40 @@ mod tests { "#; check_optimization(initial_str, expected_str, dead_code_elimination); } + + #[test] + fn eliminate_redundant_match() { + let initial_str = r#" + match { x = 1 } with + | { x } -> 1 + end + "#; + let expected_str = r#" + 1 + "#; + check_optimization(initial_str, expected_str, dead_code_elimination); + } + + #[test] + fn dont_eliminate_used_match() { + let initial_str = r#" + rec let f y = y + in + let x = f 123 + in + match { x } with + | { x } -> x + end + "#; + let expected_str = r#" + rec let f y = y + in + let x = f 123 + in + match { x } with + | { x } -> x + end + "#; + check_optimization(initial_str, expected_str, dead_code_elimination); + } } diff --git a/vm/src/core/mod.rs b/vm/src/core/mod.rs index 414aabb209..2ebe3f10f2 100644 --- a/vm/src/core/mod.rs +++ b/vm/src/core/mod.rs @@ -1910,6 +1910,14 @@ mod tests { } fn expr_eq(map: &mut HashMap, l: &Expr, r: &Expr) -> bool { + let b = expr_eq_(map, l, r); + if !b { + eprintln!("{} != {}", l, r); + } + b + } + + fn expr_eq_(map: &mut HashMap, l: &Expr, r: &Expr) -> bool { match (l, r) { (&Expr::Match(_, l_alts), &Expr::Match(_, r_alts)) => { for (l, r) in l_alts.iter().zip(r_alts) { @@ -1952,7 +1960,20 @@ mod tests { (&Expr::Ident(ref l, _), &Expr::Ident(ref r, _)) => check(map, &l.name, &r.name), (&Expr::Let(ref lb, l), &Expr::Let(ref rb, r)) => { let b = match (&lb.expr, &rb.expr) { - (&Named::Expr(le), &Named::Expr(re)) => expr_eq(map, le, re), + (Named::Expr(le), Named::Expr(re)) => expr_eq(map, le, re), + (Named::Recursive(lc), Named::Recursive(rc)) => { + lc.len() == rc.len() + && lc.iter().zip(rc).all(|(lc, rc)| { + check(map, &lc.name.name, &rc.name.name) + && lc.args.len() == rc.args.len() + && lc + .args + .iter() + .zip(&rc.args) + .all(|(l, r)| check(map, &l.name, &r.name)) + && expr_eq(map, lc.expr, rc.expr) + }) + } _ => false, }; check(map, &lb.name.name, &rb.name.name) && b && expr_eq(map, l, r)