From f86577761a713d9735493fd6d619ecbb84e947fb Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Tue, 6 Mar 2018 20:07:14 +0100 Subject: [PATCH] fix(compiler): Merge pattern bindings which appear in multiple alternatives --- check/tests/implicits.rs | 21 ++++ tests/pattern_match.rs | 37 ++++++- vm/src/core/mod.rs | 203 ++++++++++++++++++++++++++------------- 3 files changed, 190 insertions(+), 71 deletions(-) diff --git a/check/tests/implicits.rs b/check/tests/implicits.rs index 8f90413464..f4f37df9c4 100644 --- a/check/tests/implicits.rs +++ b/check/tests/implicits.rs @@ -36,6 +36,27 @@ f 42 assert_eq!(result, Ok(Type::int())); } +#[test] +fn single_implicit_implicit_arg() { + let _ = ::env_logger::try_init(); + let text = r#" +let f y : [Int] -> Int -> Int = y +/// @implicit +let i = 123 +f 42 +"#; + let (expr, result) = support::typecheck_expr(text); + + assert_eq!(result, Ok(Type::int())); + assert_eq!( + r#"let f ?implicit_arg y : [Int] -> Int -> Int = y +/// @implicit +let i = 123 +f ?i 42"#, + format::pretty_expr(text, &expr).trim() + ); +} + #[test] fn single_implicit_explicit_arg() { let _ = ::env_logger::try_init(); diff --git a/tests/pattern_match.rs b/tests/pattern_match.rs index 507d955eb2..4be80101d6 100644 --- a/tests/pattern_match.rs +++ b/tests/pattern_match.rs @@ -105,9 +105,9 @@ id (match Test 0 with 1i32 } -test_expr!{ nested_pattern, +test_expr!{ nested_pattern1, r#" -type Option a = | Some a | None +type Option a = | None | Some a match Some (Some 123) with | None -> 0 | Some None -> 1 @@ -118,7 +118,7 @@ match Some (Some 123) with test_expr!{ nested_pattern2, r#" -type Option a = | Some a | None +type Option a = | None | Some a match Some None with | None -> 0 | Some None -> 1 @@ -158,3 +158,34 @@ a #Int+ i #Int+ m "#, 20i32 } + +test_expr!{ match_with_id_binding_in_two_patterns_record, +r#" +type Option a = | None | Some a +match { _0 = 1, _1 = None } with +| { _0 = x, _1 = Some y } -> y +| { _0 = z, _1 = None } -> z +"#, +1 +} + +test_expr!{ match_with_id_binding_in_two_patterns_tuple, +r#" +type Option a = | None | Some a +match (1, None) with +| (x, Some y) -> y +| (z, None) -> z +"#, +1 +} + +test_expr!{ match_with_id_binding_in_two_patterns_variant, +r#" +type Option a = | None | Some a +match (Some 10, 1) with +| (Some y, 1) -> y +| (Some z, x) -> z +| (None, a) -> a +"#, +10 +} diff --git a/vm/src/core/mod.rs b/vm/src/core/mod.rs index b2dcd1df61..726e931ac8 100644 --- a/vm/src/core/mod.rs +++ b/vm/src/core/mod.rs @@ -1014,12 +1014,12 @@ enum CType { } use self::optimize::*; -struct ReplaceVariables<'a> { - replacements: HashMap, +struct ReplaceVariables<'a, 'b> { + replacements: &'b HashMap, allocator: &'a Allocator<'a>, } -impl<'a> Visitor<'a, 'a> for ReplaceVariables<'a> { +impl<'a, 'b> Visitor<'a, 'a> for ReplaceVariables<'a, 'b> { type Producer = SameLifetime<'a>; fn visit_expr(&mut self, expr: &'a Expr<'a>) -> Option<&'a Expr<'a>> { @@ -1041,6 +1041,22 @@ impl<'a> Visitor<'a, 'a> for ReplaceVariables<'a> { } } +fn replace_variables<'a, 'b>( + allocator: &'a Allocator<'a>, + replacements: &'b HashMap, + expr: &'a Expr<'a>, +) -> &'a Expr<'a> { + if replacements.is_empty() { + expr + } else { + ReplaceVariables { + replacements, + allocator, + }.visit_expr(expr) + .unwrap_or(expr) + } +} + /// `PatternTranslator` translated nested (AST) patterns into non-nested (core) patterns. /// /// It does this this by looking at each nested pattern as part of an `Equation` to be solved. @@ -1092,7 +1108,7 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { .map(|equation| *equation.patterns.first().unwrap()) }; - let pattern = self.pattern_identifiers(first_iter()); + let (pattern, replacements) = self.pattern_identifiers(first_iter()); // Gather the inner patterns so we can prepend them to equations let temp = first_iter() @@ -1150,6 +1166,8 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { let new_variables = self.insert_new_variables(&pattern, variables); let expr = self.translate(default, &new_variables, &new_equations); + let expr = replace_variables(&self.0.allocator, &replacements, expr); + Alternative { pattern: pattern, expr: expr, @@ -1204,7 +1222,7 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { .into_iter() .map(|key| { let equations = &groups[key]; - let pattern = self.pattern_identifiers( + let (pattern, replacements) = self.pattern_identifiers( equations .iter() .map(|equation| *equation.patterns.first().unwrap()), @@ -1231,6 +1249,8 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { let new_variables = self.insert_new_variables(&pattern, variables); let expr = self.translate(default, &new_variables, &new_equations); + let expr = replace_variables(&self.0.allocator, &replacements, expr); + Alternative { pattern: pattern, expr: expr, @@ -1289,11 +1309,13 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { }) .collect::>(), ); - let pattern = self.pattern_identifiers( + let (pattern, replacements) = self.pattern_identifiers( equations .iter() .map(|equation| *equation.patterns.first().unwrap()), ); + let expr = replace_variables(&self.0.allocator, &replacements, expr); + let alt = Alternative { pattern: pattern, expr: expr, @@ -1305,11 +1327,11 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { // EXPR // with `y`s replaced by `x` match (&alt.pattern, variables[0]) { (&Pattern::Ident(ref id), &Expr::Ident(ref expr_id, _)) => { - return ReplaceVariables { - replacements: collect![(id.name.clone(), expr_id.name.clone())], - allocator: &self.0.allocator, - }.visit_expr(expr) - .unwrap_or(expr); + return replace_variables( + &self.0.allocator, + &collect![(id.name.clone(), expr_id.name.clone())], + expr, + ); } _ => (), } @@ -1548,56 +1570,89 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { } fn extract_ident(&self, index: usize, pattern: &ast::Pattern) -> TypedIdent { - match *pattern { - ast::Pattern::Ident(ref id) => id.clone(), - ast::Pattern::As(_, ref pat) => self.extract_ident(index, &pat.value), - _ => TypedIdent { - name: Symbol::from(format!("pattern_{}", index)), - typ: pattern.env_type_of(&self.0.env), - }, - } + get_ident(pattern).unwrap_or_else(|| TypedIdent { + name: Symbol::from(format!("pattern_{}", index)), + typ: pattern.env_type_of(&self.0.env), + }) } // Gather all the identifiers of top level pattern of each of the `patterns` and create a core // pattern. // Nested patterns are ignored here. - fn pattern_identifiers<'b, 'p: 'b, I>(&self, patterns: I) -> Pattern + fn pattern_identifiers<'b, 'p: 'b, I>(&self, patterns: I) -> (Pattern, HashMap) where I: IntoIterator>, { - let mut identifiers = Vec::new(); + let mut identifiers: Vec> = Vec::new(); let mut record_fields: Vec<(TypedIdent, _)> = Vec::new(); - let mut ident = None; + let mut core_pattern = None; + + // Since we merge all patterns that match on the same thing (variants with the same tag, + // any record or tuple ...), tuple patterns + // If a field has already been seen in an earlier pattern we must make sure + // that the variable bound in this pattern/field gets replaced with the + // symbol from the earlier pattern + let mut replacements = HashMap::default(); + + fn add_duplicate_ident( + replacements: &mut HashMap, + record_fields: &mut Vec<(TypedIdent, Option)>, + field: &Symbol, + pattern: Option<&SpannedPattern>, + ) -> bool { + match record_fields.iter().find(|id| id.0.name == *field).cloned() { + Some(earlier_var) => { + let duplicate = match pattern { + Some(ref pattern) => get_ident(&pattern.value).map(|id| id.name), + None => Some(field.clone()), + }; + if let Some(duplicate) = duplicate { + replacements.insert(duplicate, earlier_var.1.unwrap_or(earlier_var.0.name)); + } + true + } + None => false, + } + } for pattern in patterns { match *unwrap_as(&pattern.value) { ast::Pattern::Constructor(ref id, ref patterns) => { - identifiers.extend( - patterns - .iter() - .enumerate() - .map(|(i, pattern)| self.extract_ident(i, &pattern.value)), - ); - // Just extract the patterns of the first constructor found - return Pattern::Constructor(id.clone(), identifiers); + core_pattern = Some(Pattern::Constructor(id.clone(), Vec::new())); + + for (i, pattern) in patterns.iter().enumerate() { + match identifiers.get(i).map(|i| i.name.clone()) { + Some(earlier_var) => { + if let Some(duplicate) = get_ident(&pattern.value).map(|id| id.name) + { + replacements.insert(duplicate, earlier_var); + } + } + None => identifiers.push(self.extract_ident(i, &pattern.value)), + } + } } ast::Pattern::As(..) => unreachable!(), - ast::Pattern::Ident(ref id) => if ident.is_none() { - ident = Some(id.clone()) + ast::Pattern::Ident(ref id) => if core_pattern.is_none() { + core_pattern = Some(Pattern::Ident(id.clone())); }, ast::Pattern::Tuple { ref typ, ref elems } => { - record_fields.extend(elems.iter().zip(typ.row_iter()).enumerate().map( - |(i, (elem, field_type))| { - ( + for (i, (elem, field_type)) in elems.iter().zip(typ.row_iter()).enumerate() { + if !add_duplicate_ident( + &mut replacements, + &mut record_fields, + &field_type.name, + Some(elem), + ) { + record_fields.push(( TypedIdent { name: field_type.name.clone(), typ: field_type.typ.clone(), }, Some(self.extract_ident(i, &elem.value).name), - ) - }, - )); - break; + )); + } + } } // Records need to merge the bindings of each of the patterns as we want the core // `match` expression to just have one alternative @@ -1611,40 +1666,52 @@ impl<'a, 'e> PatternTranslator<'a, 'e> { ref typ, ref fields, .. - } => { - for (i, field) in fields.iter().enumerate() { - // Don't add one field twice - if record_fields.iter().all(|id| id.0.name != field.name.value) { - let x = field - .value - .as_ref() - .map(|pattern| self.extract_ident(i, &pattern.value).name); - let field_type = remove_aliases_cow(&self.0.env, typ) - .row_iter() - .find(|f| f.name.name_eq(&field.name.value)) - .map(|f| f.typ.clone()) - .unwrap_or_else(|| Type::hole()); - record_fields.push(( - TypedIdent { - name: field.name.value.clone(), - typ: field_type, - }, - x, - )); - } + } => for (i, field) in fields.iter().enumerate() { + if !add_duplicate_ident( + &mut replacements, + &mut record_fields, + &field.name.value, + field.value.as_ref(), + ) { + let x = field + .value + .as_ref() + .map(|pattern| self.extract_ident(i, &pattern.value).name); + let field_type = remove_aliases_cow(&self.0.env, typ) + .row_iter() + .find(|f| f.name.name_eq(&field.name.value)) + .map(|f| f.typ.clone()) + .unwrap_or_else(|| Type::hole()); + record_fields.push(( + TypedIdent { + name: field.name.value.clone(), + typ: field_type, + }, + x, + )); } - } + }, ast::Pattern::Literal(_) | ast::Pattern::Error => (), } } - if record_fields.is_empty() { - match ident { - Some(ident) => Pattern::Ident(ident), - None => Pattern::Record(record_fields), + let pattern = match core_pattern { + Some(mut p) => { + if let Pattern::Constructor(_, ref mut ids) = p { + *ids = identifiers + } + p } - } else { - Pattern::Record(record_fields) - } + None => Pattern::Record(record_fields), + }; + (pattern, replacements) + } +} + +fn get_ident(pattern: &ast::Pattern) -> Option> { + match *pattern { + ast::Pattern::Ident(ref id) => Some(id.clone()), + ast::Pattern::As(_, ref pat) => get_ident(&pat.value), + _ => None, } }