diff --git a/check/src/typecheck.rs b/check/src/typecheck.rs index bf68ab88f7..585a89eb6a 100644 --- a/check/src/typecheck.rs +++ b/check/src/typecheck.rs @@ -3,7 +3,6 @@ //! checking of types are done in the `unify_type` and `kindcheck` modules. use std::{ borrow::{BorrowMut, Cow}, - iter::once, mem, ops::Deref, sync::Arc, @@ -1360,25 +1359,25 @@ impl<'a> Typecheck<'a> { let mut return_variables = FnvSet::default(); for arg in &mut **implicit_args { - let arg_ty = self.subs.new_var(); - let ret_ty = self.subs.new_var(); - let f = self - .subs - .function_implicit(once(arg_ty.clone()), ret_ty.clone()); - - self.subsumes(arg.span, ErrorOrder::ExpectedActual, &f, func_type.clone()); + let (arg_typ, ret_typ) = self.subsume_function( + arg.span.start(), + arg.span, + ArgType::Implicit, + func_type.clone(), + &mut Vec::new(), + ); - let arg_ty = self.typecheck(arg, ModType::wobbly(&arg_ty)); + let arg_typ = self.typecheck(arg, ModType::wobbly(&arg_typ)); - if arg_ty.modifier == TypeModifier::Rigid { - types::walk_type(&self.subs.zonk(&arg_ty), &mut |typ: &RcType| { + if arg_typ.modifier == TypeModifier::Rigid { + types::walk_type(&self.subs.zonk(&arg_typ), &mut |typ: &RcType| { if let Type::Variable(var) = &**typ { return_variables.insert(var.id); } }); } - func_type = ret_ty; + func_type = ret_typ; } let mut not_a_function_index = None; @@ -1386,8 +1385,13 @@ impl<'a> Typecheck<'a> { let mut prev_arg_end = implicit_args.last().map_or(span, |arg| arg.span).end(); for arg in args.map(|arg| arg.borrow_mut()) { let errors_before = self.errors.len(); - let (arg_ty, ret_ty) = - self.subsume_function(prev_arg_end, arg.span, func_type.clone(), implicit_args); + let (arg_ty, ret_ty) = self.subsume_function( + prev_arg_end, + arg.span, + ArgType::Explicit, + func_type.clone(), + implicit_args, + ); if errors_before != self.errors.len() { self.errors.pop(); @@ -2754,35 +2758,6 @@ impl<'a> Typecheck<'a> { } } - fn subsume_function( - &mut self, - prev_arg_end: BytePos, - span: Span, - actual: RcType, - implicit_args: &mut Vec>, - ) -> (RcType, RcType) { - let actual = self.remove_aliases(actual); - match actual.as_function_with_type() { - Some((ArgType::Explicit, arg, ret)) => return (arg.clone(), ret.clone()), - _ => (), - } - - let arg_ty = self.subs.new_var(); - let ret_ty = self.subs.new_var(); - let f = self.subs.function(once(arg_ty.clone()), ret_ty.clone()); - - self.subsumes_implicit( - span, - ErrorOrder::ExpectedActual, - &f, - actual, - &mut |implicit_arg| { - implicit_args.push(pos::spanned2(prev_arg_end, span.start(), implicit_arg)); - }, - ); - (arg_ty, ret_ty) - } - fn instantiate_sigma( &mut self, span: Span, @@ -2811,17 +2786,57 @@ impl<'a> Typecheck<'a> { } } + fn subsume_function( + &mut self, + prev_arg_end: BytePos, + span: Span, + arg_type: ArgType, + actual: RcType, + implicit_args: &mut Vec>, + ) -> (RcType, RcType) { + let (_, a, r) = self.merge_function(Some(arg_type), actual, &mut |self_, f, actual| { + self_.subsumes_implicit( + span, + ErrorOrder::ExpectedActual, + &f, + actual, + &mut |implicit_arg| { + implicit_args.push(pos::spanned2(prev_arg_end, span.start(), implicit_arg)); + }, + ); + }); + (a, r) + } + fn unify_function(&mut self, span: Span, actual: RcType) -> (ArgType, RcType, RcType) { + self.merge_function(None, actual, &mut |self_, f, actual| { + self_.unify_span(span, &f, actual); + }) + } + + fn merge_function( + &mut self, + function_arg_type: Option, + actual: RcType, + merge_fn: &mut FnMut(&mut Self, &RcType, RcType), + ) -> (ArgType, RcType, RcType) { let actual = self.remove_aliases(actual); match actual.as_function_with_type() { - Some((arg_type, arg, ret)) => return (arg_type, arg.clone(), ret.clone()), - None => (), + Some((found_arg_type, arg, ret)) + if function_arg_type == Some(found_arg_type) || function_arg_type == None => + { + return (found_arg_type, arg.clone(), ret.clone()) + } + _ => (), } let arg = self.subs.new_var(); let ret = self.subs.new_var(); - let f = self.subs.function(Some(arg.clone()), ret.clone()); - self.unify_span(span, &f, actual); - (ArgType::Explicit, arg, ret) + let arg_type = function_arg_type.unwrap_or(ArgType::Explicit); + let f = self + .subs + .function_type(arg_type, Some(arg.clone()), ret.clone()); + merge_fn(self, &f, actual); + (arg_type, arg, ret) } fn unify_span(&mut self, span: Span, expected: &RcType, actual: RcType) -> RcType {