diff --git a/check/src/implicits.rs b/check/src/implicits.rs index 7056c411b8..43acd3ca44 100644 --- a/check/src/implicits.rs +++ b/check/src/implicits.rs @@ -1,4 +1,4 @@ -use std::{convert::TryInto, fmt, rc::Rc, sync::Arc}; +use std::{fmt, rc::Rc, sync::Arc}; use {itertools::Itertools, smallvec::SmallVec}; @@ -64,11 +64,12 @@ fn split_type<'a>( type ImplicitBinding = (Rc<[TypedIdent]>, RcType); +#[derive(Clone)] pub struct Partition { - partition: FnvMap>, + partition: rpds::HashTrieMap>, // The partitioning should be very fine grained so we usually have very few elements in each // partition - rest: SmallVec<[(Level, T); 3]>, + rest: SmallVec<[T; 3]>, } impl fmt::Debug for Partition @@ -92,7 +93,7 @@ where "[{}]", self.rest .iter() - .format_with(",", |(_, i), f| f(&format_args!("{:?}", i))) + .format_with(",", |i, f| f(&format_args!("{:?}", i))) ), ) .finish() @@ -116,10 +117,9 @@ impl fmt::Display for Partition { &format_args!( "[{}]", self.rest.iter().format_with(",", |i, f| f(&format_args!( - "@{:?} {}: {}", - i.0, - (i.1).0.iter().map(|i| &i.name).format("."), - (i.1).1 + "{}: {}", + i.0.iter().map(|i| &i.name).format("."), + i.1 ))) ), ) @@ -137,25 +137,26 @@ impl Default for Partition { } impl Partition { - fn insert(&mut self, subs: &Substitution, typ: &RcType, level: Level, value: T) + fn insert(&mut self, subs: &Substitution, typ: &RcType, value: T) where T: Clone, { - let mut partition = self; - for symbol in spliterator(subs, typ) { - partition = partition.partition.entry(symbol).or_default(); - partition.rest.push((level, value.clone())); - } + let iter = spliterator(subs, typ); + self.insert_(iter, &value); } - fn remove(&mut self, subs: &Substitution, typ: &RcType) { - let mut partition = self; - for symbol in spliterator(subs, typ) { - partition = partition - .partition - .get_mut(&symbol) - .expect("Entry from insert call"); - partition.rest.pop(); + fn insert_(&mut self, mut iter: impl Iterator, value: &T) -> bool + where + T: Clone, + { + if let Some(symbol) = iter.next() { + let mut partition = self.partition.get(&symbol).cloned().unwrap_or_default(); + partition.insert_(iter, value); + partition.rest.push(value.clone()); + self.partition.insert_mut(symbol, partition); + true + } else { + false } } @@ -163,14 +164,23 @@ impl Partition { &'a self, subs: &Substitution, typ: &RcType, - implicit_bindings_level: Level, + ) -> impl DoubleEndedIterator + where + T: fmt::Debug, + { + let mut candidates = Vec::new(); + self.get_candidates_(subs, typ, &mut |t| candidates.push(t)); + candidates.into_iter() + } + + fn get_candidates_<'a>( + &'a self, + subs: &Substitution, + typ: &RcType, consumer: &mut impl FnMut(&'a T), ) where T: fmt::Debug, { - fn f(t: &(Level, U)) -> &U { - &t.1 - } let mut partition = self; for symbol in spliterator(subs, typ) { match partition.partition.get(&symbol) { @@ -178,100 +188,56 @@ impl Partition { None => break, } } - let end = partition - .rest - .iter() - .rposition(|(level, _)| *level <= implicit_bindings_level) - .map_or(0, |i| i + 1); - partition.rest[..end].iter().map(f).for_each(&mut *consumer); - } -} - -impl Partition { - fn update(&mut self, f: &mut F) - where - F: FnMut(&Symbol) -> Option, - { - for partition in self.partition.values_mut() { - partition.update(f); - } - - for (_, (path, typ)) in &mut self.rest { - if path.len() == 1 { - let name = &path[0].name; - if let Some(t) = f(name) { - *typ = t; - } - } - } + partition.rest.iter().for_each(&mut *consumer); } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct Level(u32); - -#[derive(Default, Debug)] +#[derive(Debug)] pub(crate) struct ImplicitBindings { - pub partition: Partition, - partition_insertions: Vec)>>, + pub partition: Vec>, pub definitions: ScopedMap, } -impl fmt::Display for ImplicitBindings { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.partition) +impl Default for ImplicitBindings { + fn default() -> Self { + Self::new() } } impl ImplicitBindings { + fn new() -> Self { + Self { + partition: vec![Default::default()], + definitions: Default::default(), + } + } + fn insert( &mut self, subs: &Substitution, - definition: Option<&Symbol>, path: &[TypedIdent], typ: &RcType, ) { - let level = Level(self.partition_insertions.len().try_into().unwrap()); - self.partition - .insert(subs, typ, level, (Rc::from(&path[..]), typ.clone())); - - self.partition_insertions - .push(Some((typ.clone(), definition.cloned()))); + .last_mut() + .unwrap() + .insert(subs, typ, (Rc::from(&path[..]), typ.clone())); } - pub fn update(&mut self, mut f: F) - where - F: FnMut(&Symbol) -> Option, - { - self.partition.update(&mut f); - } - - fn get_candidates<'a>( - &'a self, - subs: &Substitution, - typ: &RcType, - level: Level, - ) -> impl DoubleEndedIterator { - let mut candidates = Vec::new(); - self.partition - .get_candidates(subs, typ, level, &mut |bind| candidates.push(bind.clone())); - candidates.into_iter() + pub fn partition(&self) -> &Partition { + self.partition.last().expect("bindings") } pub fn enter_scope(&mut self) { + if let Some(bind) = self.partition.last().cloned() { + self.partition.push(bind); + } self.definitions.enter_scope(); - self.partition_insertions.push(None); } - pub fn exit_scope(&mut self, subs: &Substitution) { + pub fn exit_scope(&mut self) { + self.partition.pop(); self.definitions.exit_scope(); - while let Some(Some((typ, definition))) = self.partition_insertions.pop() { - if let Some(definition) = definition { - self.definitions.remove(&definition); - } - self.partition.remove(subs, &typ); - } } } @@ -356,25 +322,23 @@ struct ResolveImplicitsVisitor<'a, 'b: 'a> { impl<'a, 'b> ResolveImplicitsVisitor<'a, 'b> { fn resolve_implicit( &mut self, - implicit_bindings_level: Level, + implicit_bindings: &Partition, expr: &SpannedExpr, id: &TypedIdent, ) -> Option> { debug!( "Resolving {} against:\n{}", id.typ, - self.tc - .implicit_resolver - .implicit_bindings - .get_candidates(&self.tc.subs, &id.typ, implicit_bindings_level) - .map(|t| t.1) + implicit_bindings + .get_candidates(&self.tc.subs, &id.typ,) + .map(|t| &t.1) .format("\n") ); self.tc.implicit_resolver.visited.clear(); let span = expr.span; let mut to_resolve = Vec::new(); match self.find_implicit( - implicit_bindings_level, + implicit_bindings, &mut to_resolve, &Demand { reason: Default::default(), @@ -392,8 +356,8 @@ impl<'a, 'b> ResolveImplicitsVisitor<'a, 'b> { ); let resolution_result = match self.resolve_implicit_application( + implicit_bindings, 0, - implicit_bindings_level, span, &path_of_candidate, &to_resolve, @@ -441,13 +405,13 @@ impl<'a, 'b> ResolveImplicitsVisitor<'a, 'b> { fn resolve_implicit_application( &mut self, + implicit_bindings: &Partition, level: u32, - implicit_bindings_level: Level, span: Span, path: &[TypedIdent], to_resolve: &[Demand], ) -> Result>> { - self.resolve_implicit_application_(level, implicit_bindings_level, span, path, to_resolve) + self.resolve_implicit_application_(implicit_bindings, level, span, path, to_resolve) .map_err(|mut err| { if let ErrorKind::LoopInImplicitResolution(ref mut paths) = err.kind { paths.push(path.iter().map(|id| &id.name).format(".").to_string()); @@ -458,8 +422,8 @@ impl<'a, 'b> ResolveImplicitsVisitor<'a, 'b> { fn resolve_implicit_application_( &mut self, + implicit_bindings: &Partition, level: u32, - implicit_bindings_level: Level, span: Span, path: &[TypedIdent], to_resolve: &[Demand], @@ -494,12 +458,12 @@ impl<'a, 'b> ResolveImplicitsVisitor<'a, 'b> { let mut to_resolve = Vec::new(); let result = self - .find_implicit(implicit_bindings_level, &mut to_resolve, demand) + .find_implicit(implicit_bindings, &mut to_resolve, demand) .and_then(|path| { debug!("Success! Resolving arguments"); self.resolve_implicit_application( + implicit_bindings, level + 1, - implicit_bindings_level, span, &path, &to_resolve, @@ -559,15 +523,12 @@ impl<'a, 'b> ResolveImplicitsVisitor<'a, 'b> { fn find_implicit( &mut self, - implicit_bindings_level: Level, + implicit_bindings: &Partition, to_resolve: &mut Vec, demand: &Demand, ) -> Result]>> { - let mut candidates = self - .tc - .implicit_resolver - .implicit_bindings - .get_candidates(&self.tc.subs, &demand.constraint, implicit_bindings_level) + let mut candidates = implicit_bindings + .get_candidates(&self.tc.subs, &demand.constraint) .rev(); let mut snapshot = Some(self.tc.subs.snapshot()); let found_candidate = candidates.by_ref().find(|x| { @@ -667,19 +628,19 @@ impl<'a, 'b, 'c> MutVisitor<'c> for ResolveImplicitsVisitor<'a, 'b> { fn visit_expr(&mut self, expr: &mut SpannedExpr) { let mut replacement = None; if let Expr::Ident(ref id) = expr.value { - let implicit_bindings_level = self + let implicit_vars = self .tc .implicit_resolver .implicit_vars .get(&id.name) .cloned(); - if let Some(implicit_bindings_level) = implicit_bindings_level { + if let Some(implicit_vars) = implicit_vars { let typ = id.typ.clone(); let id = TypedIdent { name: id.name.clone(), typ: typ, }; - replacement = self.resolve_implicit(implicit_bindings_level, expr, &id); + replacement = self.resolve_implicit(&implicit_vars, expr, &id); } } if let Some(replacement) = replacement { @@ -696,7 +657,7 @@ pub struct ImplicitResolver<'a> { pub(crate) metadata: &'a mut FnvMap>, environment: &'a dyn TypecheckEnv, pub(crate) implicit_bindings: ImplicitBindings, - pub(crate) implicit_vars: ScopedMap, + pub(crate) implicit_vars: ScopedMap>, visited: ScopedMap, Box<[RcType]>>, alias_resolver: resolve::AliasRemover, path: Vec>, @@ -778,11 +739,10 @@ impl<'a> ImplicitResolver<'a> { let opt = self.try_create_implicit(metadata, typ); - if let Some(definition) = opt { + if let Some(_) = opt { let typ = subs.forall(forall_params.iter().cloned().collect(), typ.clone()); - self.implicit_bindings - .insert(subs, definition, &self.path, &typ); + self.implicit_bindings.insert(subs, &self.path, &typ); self.add_implicits_of_record_rec(subs, &typ, metadata, forall_params) } @@ -872,13 +832,8 @@ impl<'a> ImplicitResolver<'a> { pub fn make_implicit_ident(&mut self, _typ: &RcType) -> Symbol { let name = Symbol::from("implicit_arg"); - let implicits_revision = Level( - self.implicit_bindings - .partition_insertions - .len() - .try_into() - .unwrap(), - ); + let implicits_revision = self.implicit_bindings.partition().clone(); + self.implicit_vars.insert(name.clone(), implicits_revision); name } @@ -887,8 +842,8 @@ impl<'a> ImplicitResolver<'a> { self.implicit_bindings.enter_scope(); } - pub fn exit_scope(&mut self, subs: &Substitution) { - self.implicit_bindings.exit_scope(subs); + pub fn exit_scope(&mut self) { + self.implicit_bindings.exit_scope(); } } diff --git a/check/src/typecheck.rs b/check/src/typecheck.rs index bd947cf688..4b2c511189 100644 --- a/check/src/typecheck.rs +++ b/check/src/typecheck.rs @@ -356,11 +356,13 @@ impl<'a> Typecheck<'a> { fn enter_scope(&mut self) { self.environment.stack.enter_scope(); self.environment.stack_types.enter_scope(); + self.implicit_resolver.enter_scope(); } fn exit_scope(&mut self) { self.environment.stack.exit_scope(); self.environment.stack_types.exit_scope(); + self.implicit_resolver.exit_scope(); } fn generalize_binding( @@ -1872,7 +1874,6 @@ impl<'a> Typecheck<'a> { let mut types = Vec::new(); for (i, bind) in bindings.iter_mut().enumerate() { - self.implicit_resolver.enter_scope(); // Functions which are declared as `let f x = ...` are allowed to be self // recursive let typ = if !is_recursive { @@ -1977,11 +1978,6 @@ impl<'a> Typecheck<'a> { debug!("End generalize recursive"); } - // Update the implicit bindings with the generalized types we just created - let stack = &self.environment.stack; - self.implicit_resolver - .implicit_bindings - .update(|name| stack.get(name).map(|b| b.typ.concrete.clone())); debug!("Typecheck `in`"); self.environment.type_variables.exit_scope(); @@ -3253,10 +3249,6 @@ fn generalize_binding( binding: &mut ValueBinding, ) { crate::implicits::resolve(generalizer.tc, &mut binding.expr); - generalizer - .tc - .implicit_resolver - .exit_scope(&generalizer.tc.subs); generalizer.generalize_type_top(resolved_type); } diff --git a/check/tests/implicits.rs b/check/tests/implicits.rs index 954c9ebffe..d676f7a5be 100644 --- a/check/tests/implicits.rs +++ b/check/tests/implicits.rs @@ -75,7 +75,7 @@ f 42 "#; let (expr, result) = support::typecheck_expr(text); - assert_eq!(result, Ok(Type::int())); + assert_req!(result, Ok(Type::int())); assert_eq!( r#"let f ?__implicit_arg y : [Int] -> Int -> Int = y #[implicit] @@ -220,7 +220,7 @@ f (Test ()) let result = support::typecheck(text); let test = support::alias_variant_implicit("Test", &[], &[("Test", &[Type::unit()])], true); - assert_eq!(result, Ok(test)); + assert_req!(result, Ok(test)); } test_check! { @@ -1144,3 +1144,35 @@ x "#, "test.TestInt" } + +test_check! { +implicits_in_multiple_scopes, +r#" +#[implicit] +type Test a = { x : a } + +let module = + let test: Test Int = { x = 0 } + { test } + +let module2 = + let test: Test Int = { x = 1 } + { test } + +let test ?t : [Test a] -> a = t.x + +[ + (\_ -> + let { ? } = module + let x: Int = test + x + ), + (\_ -> + let { ? } = module2 + let x: Int = test + x + ), +] +"#, +"forall a . Array (a -> Int)" +} diff --git a/tests/pass/arithmetic.glu b/tests/pass/arithmetic.glu index d959ee989c..1a71ed0159 100644 --- a/tests/pass/arithmetic.glu +++ b/tests/pass/arithmetic.glu @@ -5,6 +5,7 @@ let { Applicative, (*>), ? } = import! std.applicative let int = import! std.int let float = import! std.float let byte @ { ? } = import! std.byte +let { empty } = import! std.monoid let { ? } = import! std.effect @@ -27,6 +28,14 @@ let int_tests = test "from_float" <| \_ -> assert_eq (int.from_float 2.0) 2, test "from_float_truncate" <| \_ -> assert_eq (int.from_float 2.7) 2, test "from_byte" <| \_ -> assert_eq (int.from_byte 2b) 2, + group "monoid" [ + test "additive" <| \_ -> + let { ? } = int.additive + assert_eq 0 empty, + test "multiplicative" <| \_ -> + let { ? } = int.multiplicative + assert_eq 1 empty, + ] ] let float_tests =