diff --git a/check/src/substitution.rs b/check/src/substitution.rs index b41ef90663..1f21a73c6f 100644 --- a/check/src/substitution.rs +++ b/check/src/substitution.rs @@ -341,6 +341,11 @@ impl Substitution { union.get_mut(other as usize).level = level; } + pub fn set_level(&self, var: u32, level: u32) { + let mut union = self.union.borrow_mut(); + union.get_mut(var as usize).level = level; + } + pub fn get_level(&self, mut var: u32) -> u32 { if let Some(v) = self.find_type_for_var(var) { var = v.get_var().map_or(var, |v| v.get_id()); @@ -552,10 +557,7 @@ impl Substitution { typ.get_var().map(|x| x.get_id()), resolved.get_var().map(|x| x.get_id()), ) { - (Some(x), Some(y)) if x > y => { - typ = Cow::Owned(resolved); - } - (_, None) => { + (Some(_), Some(_)) | (_, None) => { typ = Cow::Owned(resolved); } _ => (), diff --git a/check/src/typecheck.rs b/check/src/typecheck.rs index 111ad23630..da1dc903dc 100644 --- a/check/src/typecheck.rs +++ b/check/src/typecheck.rs @@ -348,8 +348,7 @@ impl<'a> Typecheck<'a> { Ok(typ) } None => { - // Don't report global variables inserted by the `import!` macro as non-existing - // existing + // Don't report global variables inserted by the `import!` macro as undefined // (if they don't exist the error will already have been reported by the macro) if id.is_global() { Ok(self.subs.new_var()) @@ -638,7 +637,8 @@ impl<'a> Typecheck<'a> { let returned_type; loop { let expected_type = expected_type.map(|t| self.skolemize(t)); - match self.typecheck_(expr, expected_type.as_ref()) { + let mut expected_type = expected_type.as_ref(); + match self.typecheck_(expr, &mut expected_type) { Ok(tailcall) => { match tailcall { TailCall::TailCall => { @@ -655,7 +655,13 @@ impl<'a> Typecheck<'a> { scope_count += 1; } TailCall::Type(typ) => { - returned_type = typ; + returned_type = match expected_type { + Some(expected_type) => { + let level = self.subs.var_id(); + self.merge_signature(expr.span, level, &expected_type, typ) + } + None => typ, + }; break; } } @@ -676,10 +682,12 @@ impl<'a> Typecheck<'a> { returned_type } + /// `expected_type` should be set to `None` if subsumption is done with it (to prevent us from + /// doing it twice) fn typecheck_( &mut self, expr: &mut SpannedExpr, - expected_type: Option<&ArcType>, + expected_type: &mut Option<&ArcType>, ) -> Result> { match expr.value { Expr::Ident(ref mut id) => { @@ -706,8 +714,8 @@ impl<'a> Typecheck<'a> { self.unify_span(expr_check_span(pred), &bool_type, pred_type); // Both branches must unify to the same type - let true_type = self.typecheck_opt(&mut **if_true, expected_type); - let false_type = self.typecheck_opt(&mut **if_false, expected_type); + let true_type = self.typecheck_opt(&mut **if_true, expected_type.clone()); + let false_type = self.typecheck_opt(&mut **if_false, expected_type.take()); let true_type = self.instantiate_generics(&true_type); let false_type = self.instantiate_generics(&false_type); @@ -753,7 +761,7 @@ impl<'a> Typecheck<'a> { } => { *typ = match exprs.len() { 0 => Type::unit(), - 1 => self.typecheck_opt(&mut exprs[0], expected_type), + 1 => self.typecheck_opt(&mut exprs[0], expected_type.take()), _ => { let fields = exprs .iter_mut() @@ -775,6 +783,8 @@ impl<'a> Typecheck<'a> { let typ = self.infer_expr(&mut **expr); let mut expected_alt_type = expected_type.cloned(); + let expected_type = expected_type.take(); + for alt in alts.iter_mut() { self.enter_scope(); self.typecheck_pattern(&mut alt.pattern, typ.clone()); @@ -855,10 +865,13 @@ impl<'a> Typecheck<'a> { Expr::Lambda(ref mut lambda) => { let loc = format!("{}.lambda:{}", self.symbols.module(), expr.span.start); lambda.id.name = self.symbols.symbol(loc); + let level = self.subs.var_id(); let function_type = expected_type .cloned() .unwrap_or_else(|| self.subs.new_var()); - let typ = self.typecheck_lambda(function_type, &mut lambda.args, &mut lambda.body); + let mut typ = + self.typecheck_lambda(function_type, &mut lambda.args, &mut lambda.body); + self.generalize_type(level, &mut typ); lambda.id.typ = typ.clone(); Ok(TailCall::Type(typ)) } @@ -896,23 +909,34 @@ impl<'a> Typecheck<'a> { for field in fields { let level = self.subs.var_id(); + let name = &field.name.value; + let expected_field_type = expected_type + .and_then(|expected_type| { + expected_type + .row_iter() + .find(|expected_field| expected_field.name.name_eq(&name)) + }) + .map(|field| &field.typ); + let typ = match field.value { Some(ref mut expr) => { - let name = &field.name.value; - let expected_type = expected_type - .and_then(|expected_type| { - expected_type - .row_iter() - .find(|expected_field| expected_field.name.name_eq(&name)) - }) - .map(|field| &field.typ); - - let mut typ = self.typecheck_opt(expr, expected_type); + let mut typ = self.typecheck_opt(expr, expected_field_type); self.generalize_type(level, &mut typ); new_skolem_scope(&self.subs, &FnvMap::default(), &typ) } - None => self.find(&field.name.value)?, + None => { + let typ = self.find(&field.name.value)?; + match expected_field_type { + Some(expected_field_type) => self.merge_signature( + field.name.span, + level, + &expected_field_type, + typ, + ), + None => typ, + } + } }; if self.error_on_duplicated_field(&mut duplicated_fields, field.name.clone()) { new_fields.push(Field::new(field.name.value.clone(), typ)); @@ -970,7 +994,10 @@ impl<'a> Typecheck<'a> { for expr in exprs { self.infer_expr(expr); } - Ok(TailCall::Type(self.typecheck_opt(last, expected_type))) + Ok(TailCall::Type(self.typecheck_opt( + last, + expected_type.take(), + ))) } Expr::Do(Do { ref mut id, @@ -1037,6 +1064,7 @@ impl<'a> Typecheck<'a> { where I: IntoIterator>, { + func_type = self.new_skolem_scope(&func_type); for arg in args { let f = self.type_cache .function(once(self.subs.new_var()), self.subs.new_var()); @@ -1378,24 +1406,16 @@ impl<'a> Typecheck<'a> { bind.resolved_type = typ; } - bind.resolved_type = self.new_skolem_scope_signature(&bind.resolved_type); - self.typecheck(&mut bind.expr, &bind.resolved_type) + let typ = self.new_skolem_scope_signature(&bind.resolved_type); + self.typecheck(&mut bind.expr, &typ) } else { - bind.resolved_type = self.new_skolem_scope_signature(&bind.resolved_type); - let function_type = self.instantiate_generics(&bind.resolved_type); + let typ = self.new_skolem_scope_signature(&bind.resolved_type); + let function_type = self.skolemize(&typ); self.typecheck_lambda(function_type, &mut bind.args, &mut bind.expr) }; debug!("let {:?} : {}", bind.name, typ); - let bind_span = Span::new( - bind.name.span.start, - bind.args - .last() - .map_or(bind.name.span.end, |last_arg| last_arg.span.end), - ); - typ = self.merge_signature(bind_span, level, &bind.resolved_type, typ); - if !is_recursive { // Merge the type declaration and the actual type debug!("Generalize at {} = {}", level, bind.resolved_type); @@ -1616,9 +1636,11 @@ impl<'a> Typecheck<'a> { let typ = self.instantiate_generics(&typ); let record_type = self.remove_alias(typ.clone()); with_pattern_types(fields, &record_type, |field_name, binding, field_type| { + let mut field_type = field_type.clone(); + self.generalize_type(level, &mut field_type); match *binding { Some(ref mut pat) => { - self.finish_pattern(level, pat, field_type); + self.finish_pattern(level, pat, &field_type); } None => { self.environment @@ -1628,7 +1650,7 @@ impl<'a> Typecheck<'a> { .typ = field_type.clone(); debug!("{}: {}", field_name, field_type); - self.intersect_type(level, field_name, field_type); + self.intersect_type(level, field_name, &field_type); } } }); @@ -1642,7 +1664,9 @@ impl<'a> Typecheck<'a> { let typ = self.top_skolem_scope(typ); let typ = self.instantiate_generics(&typ); for (elem, field) in elems.iter_mut().zip(typ.row_iter()) { - self.finish_pattern(level, elem, &field.typ); + let mut field_type = field.typ.clone(); + self.generalize_type(level, &mut field_type); + self.finish_pattern(level, elem, &field_type); } } Pattern::Constructor(ref id, ref mut args) => { @@ -1670,8 +1694,8 @@ impl<'a> Typecheck<'a> { // Only allow overloading for bindings whose types which do not contain type variables // It might be possible to lift this restriction but currently it causes problems // which I am not sure how to solve - debug!("Looking for intersection `{}`", symbol_type); if existing_types.len() >= 2 { + debug!("Looking for intersection `{}`", symbol_type); let existing_binding = &existing_types[existing_types.len() - 2]; debug!( "Intersect `{}`\n{} ∩ {}", @@ -1709,7 +1733,7 @@ impl<'a> Typecheck<'a> { { let constraints: FnvMap<_, _> = intersection_constraints .into_iter() - .map(|((l, mut r), name)| { + .map(|((mut l, mut r), name)| { let constraints = match *l { Type::Generic(ref gen) => existing_binding.constraints.get(&gen.id), Type::Skolem(ref skolem) => existing_binding.constraints.get(&skolem.name), @@ -1728,6 +1752,7 @@ impl<'a> Typecheck<'a> { _ => None, }; + self.generalize_type(level, &mut l); self.generalize_type(level, &mut r); ( @@ -1925,6 +1950,11 @@ impl<'a> Typecheck<'a> { id: skolem.name.clone(), kind: skolem.kind.clone(), }; + + if self.type_variables.get(&generic.id).is_none() { + unbound_variables.insert(generic.id.clone(), generic.clone()); + } + Some(Type::generic(generic)) } @@ -2328,7 +2358,13 @@ impl<'a, 'b> Iterator for FunctionArgIter<'a, 'b> { Some(typ) => (None, typ.clone()), None => return None, }, - None => (Some(self.tc.subs.new_var()), self.tc.subs.new_var()), + None => { + let arg = self.tc.subs.new_var(); + let ret = self.tc.subs.new_var(); + let f = self.tc.type_cache.function(Some(arg.clone()), ret.clone()); + self.tc.unify(&self.typ, f).unwrap(); + (Some(arg), ret) + } }, }; self.typ = new; diff --git a/check/src/unify_type.rs b/check/src/unify_type.rs index d944a61b7b..a36591ade0 100644 --- a/check/src/unify_type.rs +++ b/check/src/unify_type.rs @@ -63,18 +63,6 @@ impl<'a> State<'a> { } } - fn replace_forall( - &mut self, - typ: &ArcType, - named_variables: &mut FnvMap, - ) -> ArcType { - if self.in_alias { - typ.instantiate_generics(named_variables) - } else { - typ.skolemize(named_variables) - } - } - fn remove_aliases( &mut self, subs: &Substitution, @@ -91,7 +79,17 @@ impl<'a> State<'a> { Some(mut typ) => { loop { typ = types::walk_move_type(typ.clone(), &mut |typ| match **typ { - Type::Forall(_, _, None) => Some(typ.instantiate(subs, &FnvMap::default())), + Type::Forall(_, _, None) => { + let typ = new_skolem_scope(subs, &FnvMap::default(), typ); + if let Type::Forall(_, _, Some(ref vars)) = *typ { + for var in vars { + if let Type::Variable(ref var) = **var { + subs.set_level(var.id, 0); + } + } + } + Some(typ) + } _ => None, }); if let Some(alias_id) = typ.alias_ident() { @@ -139,8 +137,7 @@ where TypeError::FieldMismatch(ref l, ref r) => write!( f, "Field names in record do not match.\n\tExpected: {}\n\tFound: {}", - l, - r + l, r ), TypeError::UndefinedType(ref id) => write!(f, "Type `{}` does not exist.", id), TypeError::SelfRecursive(ref id) => write!( @@ -390,9 +387,9 @@ where let mut named_variables = FnvMap::default(); if unifier.state.in_alias { - let l = unifier.state.replace_forall(expected, &mut named_variables); + let l = expected.skolemize(&mut named_variables); named_variables.clear(); - let r = unifier.state.replace_forall(actual, &mut named_variables); + let r = actual.skolemize(&mut named_variables); Ok(unifier.try_match_res(&l, &r)?.map(|inner_type| { reconstruct_forall(unifier.state.subs, params, inner_type, vars) @@ -410,9 +407,7 @@ where }), ) })); - let l = unifier - .state - .replace_forall(expected_iter.typ, &mut named_variables); + let l = expected_iter.typ.skolemize(&mut named_variables); named_variables.clear(); let mut actual_iter = actual.forall_scope_iter(); @@ -430,10 +425,7 @@ where ) }, )); - let r = unifier - .state - .replace_forall(actual_iter.typ, &mut named_variables); - + let r = actual_iter.typ.skolemize(&mut named_variables); Ok(unifier.try_match_res(&l, &r)?.map(|inner_type| { reconstruct_forall(unifier.state.subs, params, inner_type, vars) @@ -442,16 +434,14 @@ where } (&Type::Forall(ref params, _, Some(ref vars)), _) => { - let l = unifier - .state - .replace_forall(expected, &mut FnvMap::default()); - Ok(unifier.try_match_res(&l, &actual)?.map(|inner_type| { - reconstruct_forall(unifier.state.subs, params, inner_type, vars) - })) + let l = expected.skolemize(&mut FnvMap::default()); + Ok(unifier + .try_match_res(&l, &actual)? + .map(|inner_type| reconstruct_forall(unifier.state.subs, params, inner_type, vars))) } (_, &Type::Forall(_, _, Some(_))) => { - let r = unifier.state.replace_forall(actual, &mut FnvMap::default()); + let r = actual.skolemize(&mut FnvMap::default()); Ok(unifier.try_match_res(expected, &r)?) } @@ -834,7 +824,6 @@ where Ok((l, r)) } - // HACK // Currently the substitution assumes that once a variable has been unified to a // concrete type it cannot be unified to another type later. @@ -868,9 +857,8 @@ pub fn new_skolem_scope( constraints: &FnvMap>, typ: &ArcType, ) -> ArcType { - types::walk_move_type( - typ.clone(), - &mut |typ| if let Type::Forall(ref params, ref inner_type, None) = **typ { + types::walk_move_type(typ.clone(), &mut |typ| { + if let Type::Forall(ref params, ref inner_type, None) = **typ { let mut skolem = Vec::new(); for param in params { let constraint = constraints.get(¶m.id).cloned(); @@ -886,8 +874,8 @@ pub fn new_skolem_scope( ))) } else { None - }, - ) + } + }) } pub fn top_skolem_scope( @@ -974,9 +962,9 @@ impl<'a, 'e> Unifier, ArcType> for Merge<'e> { (&Type::Skolem(ref skolem), &Type::Variable(ref r_var)) if subs.get_level(skolem.id) > r_var.id => { - return Err(UnifyError::Other( - TypeError::UnableToGeneralize(skolem.name.clone()), - )); + return Err(UnifyError::Other(TypeError::UnableToGeneralize( + skolem.name.clone(), + ))); } (&Type::Generic(ref l_gen), &Type::Variable(ref r_var)) => { let left = match unifier.unifier.variables.get(&l_gen.id) { @@ -989,14 +977,14 @@ impl<'a, 'e> Unifier, ArcType> for Merge<'e> { } // `r_var` is outside the scope of the generic variable. Type::Variable(ref var) if var.id > r_var.id => { - return Err(UnifyError::Other( - TypeError::UnableToGeneralize(l_gen.id.clone()), - )); + return Err(UnifyError::Other(TypeError::UnableToGeneralize( + l_gen.id.clone(), + ))); } - Type::Skolem(ref skolem) if skolem.id > r_var.id => { - return Err(UnifyError::Other( - TypeError::UnableToGeneralize(l_gen.id.clone()), - )); + Type::Skolem(ref skolem) if subs.get_level(skolem.id) > r_var.id => { + return Err(UnifyError::Other(TypeError::UnableToGeneralize( + l_gen.id.clone(), + ))); } _ => l, } @@ -1029,12 +1017,12 @@ impl<'a, 'e> Unifier, ArcType> for Merge<'e> { // // `Typecheck::find` // { id, compose, (<<) } // ``` - (&Type::Forall(ref params, ref l, None), _) => { - unifier.unifier.variables.extend( - params - .iter() - .map(|param| (param.id.clone(), subs.new_var())), - ); + (&Type::Forall(ref params, ref l, _), _) => { + let mut variables = params + .iter() + .map(|param| (param.id.clone(), subs.new_var())) + .collect(); + let l = l.instantiate_generics(&mut variables); unifier.try_match_res(&l, r) } (_, &Type::Variable(ref r)) => { diff --git a/check/tests/fail.rs b/check/tests/fail.rs index 4c145913ba..12a66ef0d0 100644 --- a/check/tests/fail.rs +++ b/check/tests/fail.rs @@ -29,7 +29,7 @@ match { x = 1 } with } #[test] -fn undefined_type() { +fn undefined_type_not_in_scope() { let _ = env_logger::init(); let text = r#" let x = @@ -402,7 +402,7 @@ eq (A 0) (B 0.0) assert_eq!( &*format!("{}", result.unwrap_err()).replace("\t", " "), - r#"test:Line: 5, Column: 10: Expected the following types to be equal + r#"test:Line: 5, Column: 11: Expected the following types to be equal Expected: test.A Found: test.B 1 errors were found during unification: @@ -410,12 +410,11 @@ Types do not match: Expected: test.A Found: test.B eq (A 0) (B 0.0) - ^~~~~~~ + ^~~~~ "# ); } - #[test] fn long_type_error_format() { let long_type: ArcType = Type::function( @@ -564,3 +563,36 @@ type Test = | Test In assert_err!(result, UndefinedType(..)); } + +#[test] +fn foldable_bug() { + let _ = ::env_logger::init(); + + let text = r#" +type Array a = { x : a } + +type Foldable (f : Type -> Type) = { + foldl : forall a b . (b -> a -> b) -> b -> f a -> b +} + +let any x = any x + +let foldable : Foldable Array = + + let foldl : forall a b . (a -> b -> b) -> b -> Array a -> b = any () + + { foldl } +() +"#; + let result = support::typecheck(text); + + assert_multi_unify_err!( + result, + [ + TypeMismatch(..), + TypeMismatch(..), + TypeMismatch(..), + TypeMismatch(..) + ] + ); +} diff --git a/check/tests/forall.rs b/check/tests/forall.rs index cb7b205a08..5a5a813a09 100644 --- a/check/tests/forall.rs +++ b/check/tests/forall.rs @@ -335,7 +335,6 @@ fn field_access_tuple() { assert_eq!(result, Ok(Type::int())); } - #[test] fn unit_tuple_match() { let _ = ::env_logger::init(); @@ -415,7 +414,6 @@ x assert_eq!(result, Ok(Type::int())); } - #[test] fn record_expr_base() { let _ = ::env_logger::init(); @@ -675,7 +673,6 @@ let { List, f } = make 1 assert!(result.is_ok(), "{}", result.unwrap_err()); } - // Unsure if this should be able to compile as is (without type annotations) #[test] #[ignore] @@ -984,7 +981,6 @@ let show : Show a -> Show (List a) = \d -> assert!(result.is_ok(), "{}", result.unwrap_err()); } - #[test] fn show_list_bug_with_as_pattern() { let _ = ::env_logger::init(); @@ -1010,3 +1006,62 @@ list.show int_show assert!(result.is_ok(), "{}", result.unwrap_err()); } + +#[test] +fn generalize_record_unpacks() { + let _ = ::env_logger::init(); + + let text = r#" +type Semigroup a = { + append : a -> a -> a +} + +/// A linked list type +type List a = | Nil | Cons a (List a) + +let semigroup : Semigroup (List a) = + let append xs ys = + match xs with + | Cons x zs -> Cons x (append zs ys) + | Nil -> ys + + { append } + +let { append } = semigroup + +append (Cons 1 Nil) Nil +append (Cons "" Nil) Nil +"#; + let result = support::typecheck(text); + + assert!(result.is_ok(), "{}", result.unwrap_err()); +} + +#[test] +#[ignore] +fn generalize_tuple_unpacks() { + let _ = ::env_logger::init(); + + let text = r#" +type Semigroup a = (a -> a -> a, Int) + +/// A linked list type +type List a = | Nil | Cons a (List a) + +let semigroup : Semigroup (List a) = + let append xs ys = + match xs with + | Cons x zs -> Cons x (append zs ys) + | Nil -> ys + + (append, 0) + +let (append, _) = semigroup + +append (Cons 1 Nil) Nil +append (Cons "" Nil) Nil +"#; + let result = support::typecheck(text); + + assert!(result.is_ok(), "{}", result.unwrap_err()); +} diff --git a/check/tests/support/mod.rs b/check/tests/support/mod.rs index 09d41dff5a..28dc3ff68a 100644 --- a/check/tests/support/mod.rs +++ b/check/tests/support/mod.rs @@ -390,9 +390,10 @@ macro_rules! assert_multi_unify_err { } None => { assert!(false, - "Found {} less errors than expected at {}.\n\ + "Found {} errors but expected {} than expected at {}.\n\ Errors:\n{}\nbut expected {}", - expected_count - errors.len(), + errors.len(), + expected_count, i, error, stringify!($id)