From e423bcc8d52a34c728d555a302fe72f46bb441b8 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Thu, 18 Jun 2020 10:32:25 -0300 Subject: [PATCH] Compiler: implement autocasting in a better way --- spec/compiler/codegen/automatic_cast.cr | 21 +++++ spec/compiler/semantic/automatic_cast_spec.cr | 29 ++++++ src/compiler/crystal/program.cr | 4 +- src/compiler/crystal/semantic/call.cr | 44 ++++----- .../crystal/semantic/method_lookup.cr | 94 ++++++++++++++----- src/compiler/crystal/semantic/restrictions.cr | 15 +++ 6 files changed, 155 insertions(+), 52 deletions(-) diff --git a/spec/compiler/codegen/automatic_cast.cr b/spec/compiler/codegen/automatic_cast.cr index 0de8d9eef7d4..1d3b0575e1b5 100644 --- a/spec/compiler/codegen/automatic_cast.cr +++ b/spec/compiler/codegen/automatic_cast.cr @@ -239,4 +239,25 @@ describe "Code gen: automatic cast" do foo(1, "a" || 1) )).to_i.should eq(20) end + + it "does multidispatch with automatic casting (3)" do + run(%( + abstract class Foo + end + + class Bar < Foo + def foo(x : UInt8) + 2 + end + end + + class Baz < Foo + def foo(x : UInt8) + 3 + end + end + + Bar.new.as(Foo).foo(1) + )).to_i.should eq(2) + end end diff --git a/spec/compiler/semantic/automatic_cast_spec.cr b/spec/compiler/semantic/automatic_cast_spec.cr index 375379992b21..feed10392f62 100644 --- a/spec/compiler/semantic/automatic_cast_spec.cr +++ b/spec/compiler/semantic/automatic_cast_spec.cr @@ -546,4 +546,33 @@ describe "Semantic: automatic cast" do fill() )) { types["AnotherColor"] } end + + it "doesn't do multidispatch if an overload matches exactly (#8217)" do + assert_type(%( + abstract class Foo + end + + class Bar < Foo + def foo(x : Int64) + x + end + + def foo(*xs : Int64) + xs + end + end + + class Baz < Foo + def foo(x : Int64) + x + end + + def foo(*xs : Int64) + xs + end + end + + Baz.new.as(Foo).foo(1) + )) { int64 } + end end diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr index cd576c8f57bd..1730958be238 100644 --- a/src/compiler/crystal/program.cr +++ b/src/compiler/crystal/program.cr @@ -581,8 +581,8 @@ module Crystal end end - def lookup_private_matches(filename, signature) - file_module?(filename).try &.lookup_matches(signature) + def lookup_private_matches(filename, signature, analyze_all = false) + file_module?(filename).try &.lookup_matches(signature, analyze_all: analyze_all) end def file_module?(filename) diff --git a/src/compiler/crystal/semantic/call.cr b/src/compiler/crystal/semantic/call.cr index 33808a8427ac..3c69bfb6ecac 100644 --- a/src/compiler/crystal/semantic/call.cr +++ b/src/compiler/crystal/semantic/call.cr @@ -247,7 +247,7 @@ class Crystal::Call matches = lookup_matches_checking_expansion(owner, signature, with_literals: with_literals) if matches.empty? && owner.class? && owner.abstract? - matches = owner.virtual_type.lookup_matches(signature) + matches = owner.virtual_type.lookup_matches(signature, analyze_all: with_literals) end if matches.empty? @@ -263,27 +263,27 @@ class Crystal::Call signature = CallSignature.new(def_name, arg_types, block, named_args_types) matches = check_tuple_indexer(owner, def_name, args, arg_types) - matches ||= lookup_matches_checking_expansion(owner, signature, search_in_parents) + matches ||= lookup_matches_checking_expansion(owner, signature, search_in_parents, with_literals: with_literals) # If we didn't find a match and this call doesn't have a receiver, # and we are not at the top level, let's try searching the top-level if matches.empty? && !obj && owner != program && search_in_toplevel - program_matches = lookup_matches_with_signature(program, signature, search_in_parents) + program_matches = lookup_matches_with_signature(program, signature, search_in_parents, with_literals) matches = program_matches unless program_matches.empty? end if matches.empty? && owner.class? && owner.abstract? && name != "super" - matches = owner.virtual_type.lookup_matches(signature) + matches = owner.virtual_type.lookup_matches(signature, analyze_all: with_literals) end if matches.empty? defined_method_missing = owner.check_method_missing(signature, self) if defined_method_missing - matches = owner.lookup_matches(signature) + matches = owner.lookup_matches(signature, analyze_all: with_literals) elsif with_scope = @with_scope defined_method_missing = with_scope.check_method_missing(signature, self) if defined_method_missing - matches = with_scope.lookup_matches(signature) + matches = with_scope.lookup_matches(signature, analyze_all: with_literals) @uses_with_scope = true end end @@ -321,9 +321,9 @@ class Crystal::Call matches = bubbling_exception do target = parent_visitor.typed_def.original_owner if search_in_parents - target.lookup_matches signature + target.lookup_matches(signature, analyze_all: with_literals) else - target.lookup_matches_without_parents signature + target.lookup_matches_without_parents(signature, analyze_all: with_literals) end end matches.each do |match| @@ -332,48 +332,36 @@ class Crystal::Call end matches else - bubbling_exception { lookup_matches_with_signature(owner, signature, search_in_parents) } + bubbling_exception { lookup_matches_with_signature(owner, signature, search_in_parents, with_literals) } end end - def lookup_matches_with_signature(owner : Program, signature, search_in_parents) + def lookup_matches_with_signature(owner : Program, signature, search_in_parents, with_literals) location = self.location if location && (filename = location.original_filename) - matches = owner.lookup_private_matches filename, signature + matches = owner.lookup_private_matches(filename, signature, analyze_all: with_literals) end if matches if matches.empty? - matches = owner.lookup_matches signature + matches = owner.lookup_matches(signature, analyze_all: with_literals) end else - matches = owner.lookup_matches signature + matches = owner.lookup_matches(signature, analyze_all: with_literals) end matches end - def lookup_matches_with_signature(owner, signature, search_in_parents) + def lookup_matches_with_signature(owner, signature, search_in_parents, with_literals) if search_in_parents - owner.lookup_matches signature + owner.lookup_matches(signature, analyze_all: with_literals) else - owner.lookup_matches_without_parents signature + owner.lookup_matches_without_parents(signature, analyze_all: with_literals) end end def instantiate(signature, matches, owner, self_type, with_literals) - if with_literals - # Now that we have all our matches, check if any of them matches exactly - # all types, assuming autocasted values will always match (because they - # matches and they were not ambiguous). If so, only keep matches up to - # that exact match. We need to do this here because with autocasting - # we consider all overloads to detect ambiguous usage. - stop_index = matches.index do |match| - signature.matches_exactly?(match, with_literals: true) - end - matches = matches[..stop_index] if stop_index - end - matches.each &.remove_literals if with_literals block = @block diff --git a/src/compiler/crystal/semantic/method_lookup.cr b/src/compiler/crystal/semantic/method_lookup.cr index 0b13572ec657..78569c9d1a9a 100644 --- a/src/compiler/crystal/semantic/method_lookup.cr +++ b/src/compiler/crystal/semantic/method_lookup.cr @@ -1,5 +1,55 @@ require "../types" +# Looking up matches involves two steps: +# +# 1. Lookup is done with autocasting disabled. +# +# In this scenario as soon as we find an exact match we don't look at other +# overloads because the exact match will prevent them from being considered. +# +# If no matches are found we try again but this time with autocasting enabled. +# In `semantic/call.cr` this is when `with_literals` is `true`, and this is when +# `analyze_all` will be `true` here. +# +# 2. Lookup is done with autocasting enabled. +# +# In this mode the types for NumberLiteral and SymbolLiteral are not the usual +# types but instead the special NumberLiteralType and SymbolLiteralType. +# +# In this mode we also need to stop as soon as we find an exact match +# (which just means when the first overload matches with autocasting, which +# is for example when passing 1 to an Int64 restriction) but we still need +# to analyze all possible methods in case there's an ambiguity. For example: +# +# ``` +# def foo(x : Int64) +# end +# +# def foo(x : Int8) +# end +# +# foo(1) +# ``` +# +# In the example above we can't just stop at the first overload because +# we need to analyze the second overload to find out that the call is ambiguous. +# +# However, consider this: +# +# ``` +# def foo(x : Int64) +# end +# +# def foo(x : *Int64) +# end +# +# foo(1) +# ``` +# +# In this case there's no ambiguity: 1 means `Int64`. However, the first overload +# is an exact match and there's no need to consider the second overload in the +# multidispatch. However, we do need to analyze it to check if there's an ambiguity. + module Crystal record NamedArgumentType, name : String, type : Type do def self.from_args(named_args : Array(NamedArgument)?, with_literals = false) @@ -16,8 +66,8 @@ module Crystal named_args : Array(NamedArgumentType)? class Type - def lookup_matches(signature, owner = self, path_lookup = self, matches_array = nil) - matches = lookup_matches_without_parents(signature, owner, path_lookup, matches_array) + def lookup_matches(signature, owner = self, path_lookup = self, matches_array = nil, analyze_all = false) + matches = lookup_matches_without_parents(signature, owner, path_lookup, matches_array, analyze_all: analyze_all) return matches if matches.cover_all? matches_array = matches.matches @@ -38,7 +88,7 @@ module Crystal # and can be known by invoking `lookup_new_in_ancestors?` if my_parents && !(is_new && !lookup_new_in_ancestors?) my_parents.each do |parent| - matches = parent.lookup_matches(signature, owner, parent, matches_array) + matches = parent.lookup_matches(signature, owner, parent, matches_array, analyze_all: analyze_all) if matches.cover_all? return matches else @@ -55,10 +105,12 @@ module Crystal Matches.new(matches_array, cover, owner, false) end - def lookup_matches_without_parents(signature, owner = self, path_lookup = self, matches_array = nil) + def lookup_matches_without_parents(signature, owner = self, path_lookup = self, matches_array = nil, analyze_all = false) if defs = self.defs.try &.[signature.name]? context = MatchContext.new(owner, path_lookup) + exact_match = nil + defs.each do |item| next if item.def.abstract? @@ -72,6 +124,8 @@ module Crystal match = signature.match(item, context) + next if exact_match + if match matches_array ||= [] of Match matches_array << match @@ -81,7 +135,8 @@ module Crystal # a function type with return T can be transpass a restriction of a function # with the same arguments but which returns Void. if signature.matches_exactly?(match) - return Matches.new(matches_array, true, owner) + exact_match = Matches.new(matches_array, true, owner) + break unless analyze_all end context = MatchContext.new(owner, path_lookup) @@ -90,13 +145,17 @@ module Crystal context.def_free_vars = nil end end + + if exact_match + return exact_match + end end Matches.new(matches_array, Cover.create(signature, matches_array), owner) end - def lookup_matches_with_modules(signature, owner = self, path_lookup = self, matches_array = nil) - matches = lookup_matches_without_parents(signature, owner, path_lookup, matches_array) + def lookup_matches_with_modules(signature, owner = self, path_lookup = self, matches_array = nil, analyze_all = false) + matches = lookup_matches_without_parents(signature, owner, path_lookup, matches_array, analyze_all: analyze_all) return matches unless matches.empty? is_new = owner.metaclass? && signature.name == "new" @@ -115,7 +174,7 @@ module Crystal my_parents.each do |parent| break unless parent.module? - matches = parent.lookup_matches_with_modules(signature, owner, parent, matches_array) + matches = parent.lookup_matches_with_modules(signature, owner, parent, matches_array, analyze_all: analyze_all) return matches unless matches.empty? end end @@ -304,10 +363,6 @@ module Crystal def matches_exactly?(match : Match, *, with_literals : Bool = false) arg_types_equal = self.arg_types.equals?(match.arg_types) do |x, y| - if with_literals && x.is_a?(LiteralType) - x = x.match || x.remove_literal - end - x.compatible_with?(y) end if (match_named_args = match.named_arg_types) && (signature_named_args = self.named_args) && @@ -315,12 +370,7 @@ module Crystal match_named_args = match_named_args.sort_by &.name signature_named_args = signature_named_args.sort_by &.name named_arg_types_equal = signature_named_args.equals?(match_named_args) do |x, y| - x_type = x.type - if with_literals && x_type.is_a?(LiteralType) - x_type = x_type.match || x_type.remove_literal - end - - x.name == y.name && x_type.compatible_with?(y.type) + x.name == y.name && x.type.compatible_with?(y.type) end else named_arg_types_equal = !match.named_arg_types && !self.named_args @@ -341,11 +391,11 @@ module Crystal type end - def lookup_matches(signature, owner = self, path_lookup = self) + def lookup_matches(signature, owner = self, path_lookup = self, analyze_all = false) is_new = virtual_metaclass? && signature.name == "new" base_type_lookup = virtual_lookup(base_type) - base_type_matches = base_type_lookup.lookup_matches(signature, self) + base_type_matches = base_type_lookup.lookup_matches(signature, self, analyze_all: analyze_all) # If there are no subclasses no need to look further if leaf? @@ -369,7 +419,7 @@ module Crystal subtype_virtual_lookup = virtual_lookup(subtype.virtual_type) # Check matches but without parents: only included modules - subtype_matches = subtype_lookup.lookup_matches_with_modules(signature, subtype_virtual_lookup, subtype_virtual_lookup) + subtype_matches = subtype_lookup.lookup_matches_with_modules(signature, subtype_virtual_lookup, subtype_virtual_lookup, analyze_all: analyze_all) # For Foo+.class#new we need to check that this subtype doesn't define # an incompatible initialize: if so, we return empty matches, because @@ -390,7 +440,7 @@ module Crystal base_type_matches.each do |base_type_match| if base_type_match.def.macro_def? # We need to copy each submatch if it's a macro def - full_subtype_matches = subtype_lookup.lookup_matches(signature, subtype_virtual_lookup, subtype_virtual_lookup) + full_subtype_matches = subtype_lookup.lookup_matches(signature, subtype_virtual_lookup, subtype_virtual_lookup, analyze_all: analyze_all) full_subtype_matches.each do |full_subtype_match| cloned_def = full_subtype_match.def.clone cloned_def.macro_owner = full_subtype_match.def.macro_owner diff --git a/src/compiler/crystal/semantic/restrictions.cr b/src/compiler/crystal/semantic/restrictions.cr index 1e862e60c1b7..d6049159304f 100644 --- a/src/compiler/crystal/semantic/restrictions.cr +++ b/src/compiler/crystal/semantic/restrictions.cr @@ -1240,6 +1240,10 @@ module Crystal type end end + + def compatible_with?(type) + literal.type == type || literal.can_be_autocast_to?(type) + end end class SymbolLiteralType @@ -1264,6 +1268,17 @@ module Crystal type end end + + def compatible_with?(type) + case type + when SymbolType + true + when EnumType + !!(type.find_member(literal.value)) + else + false + end + end end end