Skip to content

Commit

Permalink
fix(compiler): Merge pattern bindings which appear in multiple
Browse files Browse the repository at this point in the history
alternatives
  • Loading branch information
Marwes committed Mar 10, 2018
1 parent 989ca97 commit f865777
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 71 deletions.
21 changes: 21 additions & 0 deletions check/tests/implicits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@ f 42
assert_eq!(result, Ok(Type::int()));
}

#[test]
fn single_implicit_implicit_arg() {
let _ = ::env_logger::try_init();
let text = r#"
let f y : [Int] -> Int -> Int = y
/// @implicit
let i = 123
f 42
"#;
let (expr, result) = support::typecheck_expr(text);

assert_eq!(result, Ok(Type::int()));
assert_eq!(
r#"let f ?implicit_arg y : [Int] -> Int -> Int = y
/// @implicit
let i = 123
f ?i 42"#,
format::pretty_expr(text, &expr).trim()
);
}

#[test]
fn single_implicit_explicit_arg() {
let _ = ::env_logger::try_init();
Expand Down
37 changes: 34 additions & 3 deletions tests/pattern_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ id (match Test 0 with
1i32
}

test_expr!{ nested_pattern,
test_expr!{ nested_pattern1,
r#"
type Option a = | Some a | None
type Option a = | None | Some a
match Some (Some 123) with
| None -> 0
| Some None -> 1
Expand All @@ -118,7 +118,7 @@ match Some (Some 123) with

test_expr!{ nested_pattern2,
r#"
type Option a = | Some a | None
type Option a = | None | Some a
match Some None with
| None -> 0
| Some None -> 1
Expand Down Expand Up @@ -158,3 +158,34 @@ a #Int+ i #Int+ m
"#,
20i32
}

test_expr!{ match_with_id_binding_in_two_patterns_record,
r#"
type Option a = | None | Some a
match { _0 = 1, _1 = None } with
| { _0 = x, _1 = Some y } -> y
| { _0 = z, _1 = None } -> z
"#,
1
}

test_expr!{ match_with_id_binding_in_two_patterns_tuple,
r#"
type Option a = | None | Some a
match (1, None) with
| (x, Some y) -> y
| (z, None) -> z
"#,
1
}

test_expr!{ match_with_id_binding_in_two_patterns_variant,
r#"
type Option a = | None | Some a
match (Some 10, 1) with
| (Some y, 1) -> y
| (Some z, x) -> z
| (None, a) -> a
"#,
10
}
203 changes: 135 additions & 68 deletions vm/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1014,12 +1014,12 @@ enum CType {
}

use self::optimize::*;
struct ReplaceVariables<'a> {
replacements: HashMap<Symbol, Symbol>,
struct ReplaceVariables<'a, 'b> {
replacements: &'b HashMap<Symbol, Symbol>,
allocator: &'a Allocator<'a>,
}

impl<'a> Visitor<'a, 'a> for ReplaceVariables<'a> {
impl<'a, 'b> Visitor<'a, 'a> for ReplaceVariables<'a, 'b> {
type Producer = SameLifetime<'a>;

fn visit_expr(&mut self, expr: &'a Expr<'a>) -> Option<&'a Expr<'a>> {
Expand All @@ -1041,6 +1041,22 @@ impl<'a> Visitor<'a, 'a> for ReplaceVariables<'a> {
}
}

fn replace_variables<'a, 'b>(
allocator: &'a Allocator<'a>,
replacements: &'b HashMap<Symbol, Symbol>,
expr: &'a Expr<'a>,
) -> &'a Expr<'a> {
if replacements.is_empty() {
expr
} else {
ReplaceVariables {
replacements,
allocator,
}.visit_expr(expr)
.unwrap_or(expr)
}
}

/// `PatternTranslator` translated nested (AST) patterns into non-nested (core) patterns.
///
/// It does this this by looking at each nested pattern as part of an `Equation` to be solved.
Expand Down Expand Up @@ -1092,7 +1108,7 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
.map(|equation| *equation.patterns.first().unwrap())
};

let pattern = self.pattern_identifiers(first_iter());
let (pattern, replacements) = self.pattern_identifiers(first_iter());

// Gather the inner patterns so we can prepend them to equations
let temp = first_iter()
Expand Down Expand Up @@ -1150,6 +1166,8 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
let new_variables = self.insert_new_variables(&pattern, variables);

let expr = self.translate(default, &new_variables, &new_equations);
let expr = replace_variables(&self.0.allocator, &replacements, expr);

Alternative {
pattern: pattern,
expr: expr,
Expand Down Expand Up @@ -1204,7 +1222,7 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
.into_iter()
.map(|key| {
let equations = &groups[key];
let pattern = self.pattern_identifiers(
let (pattern, replacements) = self.pattern_identifiers(
equations
.iter()
.map(|equation| *equation.patterns.first().unwrap()),
Expand All @@ -1231,6 +1249,8 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
let new_variables = self.insert_new_variables(&pattern, variables);

let expr = self.translate(default, &new_variables, &new_equations);
let expr = replace_variables(&self.0.allocator, &replacements, expr);

Alternative {
pattern: pattern,
expr: expr,
Expand Down Expand Up @@ -1289,11 +1309,13 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
})
.collect::<Vec<_>>(),
);
let pattern = self.pattern_identifiers(
let (pattern, replacements) = self.pattern_identifiers(
equations
.iter()
.map(|equation| *equation.patterns.first().unwrap()),
);
let expr = replace_variables(&self.0.allocator, &replacements, expr);

let alt = Alternative {
pattern: pattern,
expr: expr,
Expand All @@ -1305,11 +1327,11 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
// EXPR // with `y`s replaced by `x`
match (&alt.pattern, variables[0]) {
(&Pattern::Ident(ref id), &Expr::Ident(ref expr_id, _)) => {
return ReplaceVariables {
replacements: collect![(id.name.clone(), expr_id.name.clone())],
allocator: &self.0.allocator,
}.visit_expr(expr)
.unwrap_or(expr);
return replace_variables(
&self.0.allocator,
&collect![(id.name.clone(), expr_id.name.clone())],
expr,
);
}
_ => (),
}
Expand Down Expand Up @@ -1548,56 +1570,89 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
}

fn extract_ident(&self, index: usize, pattern: &ast::Pattern<Symbol>) -> TypedIdent<Symbol> {
match *pattern {
ast::Pattern::Ident(ref id) => id.clone(),
ast::Pattern::As(_, ref pat) => self.extract_ident(index, &pat.value),
_ => TypedIdent {
name: Symbol::from(format!("pattern_{}", index)),
typ: pattern.env_type_of(&self.0.env),
},
}
get_ident(pattern).unwrap_or_else(|| TypedIdent {
name: Symbol::from(format!("pattern_{}", index)),
typ: pattern.env_type_of(&self.0.env),
})
}

// Gather all the identifiers of top level pattern of each of the `patterns` and create a core
// pattern.
// Nested patterns are ignored here.
fn pattern_identifiers<'b, 'p: 'b, I>(&self, patterns: I) -> Pattern
fn pattern_identifiers<'b, 'p: 'b, I>(&self, patterns: I) -> (Pattern, HashMap<Symbol, Symbol>)
where
I: IntoIterator<Item = &'b SpannedPattern<Symbol>>,
{
let mut identifiers = Vec::new();
let mut identifiers: Vec<TypedIdent<Symbol>> = Vec::new();
let mut record_fields: Vec<(TypedIdent<Symbol>, _)> = Vec::new();
let mut ident = None;
let mut core_pattern = None;

// Since we merge all patterns that match on the same thing (variants with the same tag,
// any record or tuple ...), tuple patterns
// If a field has already been seen in an earlier pattern we must make sure
// that the variable bound in this pattern/field gets replaced with the
// symbol from the earlier pattern
let mut replacements = HashMap::default();

fn add_duplicate_ident(
replacements: &mut HashMap<Symbol, Symbol>,
record_fields: &mut Vec<(TypedIdent<Symbol>, Option<Symbol>)>,
field: &Symbol,
pattern: Option<&SpannedPattern<Symbol>>,
) -> bool {
match record_fields.iter().find(|id| id.0.name == *field).cloned() {
Some(earlier_var) => {
let duplicate = match pattern {
Some(ref pattern) => get_ident(&pattern.value).map(|id| id.name),
None => Some(field.clone()),
};
if let Some(duplicate) = duplicate {
replacements.insert(duplicate, earlier_var.1.unwrap_or(earlier_var.0.name));
}
true
}
None => false,
}
}

for pattern in patterns {
match *unwrap_as(&pattern.value) {
ast::Pattern::Constructor(ref id, ref patterns) => {
identifiers.extend(
patterns
.iter()
.enumerate()
.map(|(i, pattern)| self.extract_ident(i, &pattern.value)),
);
// Just extract the patterns of the first constructor found
return Pattern::Constructor(id.clone(), identifiers);
core_pattern = Some(Pattern::Constructor(id.clone(), Vec::new()));

for (i, pattern) in patterns.iter().enumerate() {
match identifiers.get(i).map(|i| i.name.clone()) {
Some(earlier_var) => {
if let Some(duplicate) = get_ident(&pattern.value).map(|id| id.name)
{
replacements.insert(duplicate, earlier_var);
}
}
None => identifiers.push(self.extract_ident(i, &pattern.value)),
}
}
}
ast::Pattern::As(..) => unreachable!(),
ast::Pattern::Ident(ref id) => if ident.is_none() {
ident = Some(id.clone())
ast::Pattern::Ident(ref id) => if core_pattern.is_none() {
core_pattern = Some(Pattern::Ident(id.clone()));
},
ast::Pattern::Tuple { ref typ, ref elems } => {
record_fields.extend(elems.iter().zip(typ.row_iter()).enumerate().map(
|(i, (elem, field_type))| {
(
for (i, (elem, field_type)) in elems.iter().zip(typ.row_iter()).enumerate() {
if !add_duplicate_ident(
&mut replacements,
&mut record_fields,
&field_type.name,
Some(elem),
) {
record_fields.push((
TypedIdent {
name: field_type.name.clone(),
typ: field_type.typ.clone(),
},
Some(self.extract_ident(i, &elem.value).name),
)
},
));
break;
));
}
}
}
// Records need to merge the bindings of each of the patterns as we want the core
// `match` expression to just have one alternative
Expand All @@ -1611,40 +1666,52 @@ impl<'a, 'e> PatternTranslator<'a, 'e> {
ref typ,
ref fields,
..
} => {
for (i, field) in fields.iter().enumerate() {
// Don't add one field twice
if record_fields.iter().all(|id| id.0.name != field.name.value) {
let x = field
.value
.as_ref()
.map(|pattern| self.extract_ident(i, &pattern.value).name);
let field_type = remove_aliases_cow(&self.0.env, typ)
.row_iter()
.find(|f| f.name.name_eq(&field.name.value))
.map(|f| f.typ.clone())
.unwrap_or_else(|| Type::hole());
record_fields.push((
TypedIdent {
name: field.name.value.clone(),
typ: field_type,
},
x,
));
}
} => for (i, field) in fields.iter().enumerate() {
if !add_duplicate_ident(
&mut replacements,
&mut record_fields,
&field.name.value,
field.value.as_ref(),
) {
let x = field
.value
.as_ref()
.map(|pattern| self.extract_ident(i, &pattern.value).name);
let field_type = remove_aliases_cow(&self.0.env, typ)
.row_iter()
.find(|f| f.name.name_eq(&field.name.value))
.map(|f| f.typ.clone())
.unwrap_or_else(|| Type::hole());
record_fields.push((
TypedIdent {
name: field.name.value.clone(),
typ: field_type,
},
x,
));
}
}
},
ast::Pattern::Literal(_) | ast::Pattern::Error => (),
}
}
if record_fields.is_empty() {
match ident {
Some(ident) => Pattern::Ident(ident),
None => Pattern::Record(record_fields),
let pattern = match core_pattern {
Some(mut p) => {
if let Pattern::Constructor(_, ref mut ids) = p {
*ids = identifiers
}
p
}
} else {
Pattern::Record(record_fields)
}
None => Pattern::Record(record_fields),
};
(pattern, replacements)
}
}

fn get_ident(pattern: &ast::Pattern<Symbol>) -> Option<TypedIdent<Symbol>> {
match *pattern {
ast::Pattern::Ident(ref id) => Some(id.clone()),
ast::Pattern::As(_, ref pat) => get_ident(&pat.value),
_ => None,
}
}

Expand Down

0 comments on commit f865777

Please sign in to comment.