diff --git a/spec/compiler/semantic/macro_spec.cr b/spec/compiler/semantic/macro_spec.cr index fdc71fdfc01f..7868840cc73a 100644 --- a/spec/compiler/semantic/macro_spec.cr +++ b/spec/compiler/semantic/macro_spec.cr @@ -775,6 +775,22 @@ describe "Semantic: macro" do )) { int32.metaclass } end + it "finds generic type argument of included module with self" do + assert_type(%( + module Bar(T) + def t + {{ T }} + end + end + + class Foo(U) + include Bar(self) + end + + Foo(Int32).new.t + )) { generic_class("Foo", int32).metaclass } + end + it "finds free type vars" do assert_type(%( module Foo(T) diff --git a/spec/compiler/semantic/module_spec.cr b/spec/compiler/semantic/module_spec.cr index 35c49a9d5d9b..e5fc5cab4c18 100644 --- a/spec/compiler/semantic/module_spec.cr +++ b/spec/compiler/semantic/module_spec.cr @@ -226,7 +226,115 @@ describe "Semantic: module" do end end - class Baz(X) + class Bar(U) + include Foo(self) + end + + Bar(Int32).new.foo + ") { generic_class("Bar", int32).metaclass } + end + + it "includes generic module with self, and inherits it" do + assert_type(" + module Foo(T) + def foo + T + end + end + + class Bar(U) + include Foo(self) + end + + class Baz < Bar(Int32) + end + + Baz.new.foo + ") { types["Baz"].metaclass } + end + + it "includes generic module with self (check argument type, success)" do + assert_type(" + module Foo(T) + def foo(x : T) + x + end + end + + class Bar(U) + include Foo(self) + end + + Bar(Int32).new.foo Bar(Int32).new + ") { generic_class("Bar", int32) } + end + + it "includes generic module with self (check argument superclass type, success)" do + assert_type(" + module Foo(T) + def foo(x : T) + x + end + end + + class Bar(U) + include Foo(self) + end + + class Baz < Bar(Int32) + end + + Bar(Int32).new.foo Baz.new + ") { types["Baz"] } + end + + it "includes generic module with self (check argument type, error)" do + assert_error " + module Foo(T) + def foo(x : T) + x + end + end + + class Bar(U) + include Foo(self) + end + + class Baz1 < Bar(Int32) + end + + class Baz2 < Bar(Int32) + end + + Baz1.new.foo Baz2.new + ", "no overload matches" + end + + it "includes generic module with self (check argument superclass type, error)" do + assert_error " + module Foo(T) + def foo(x : T) + x + end + end + + class Bar(U) + include Foo(self) + end + + class Baz < Bar(Int32) + end + + Baz.new.foo Bar(Int32).new + ", "no overload matches" + end + + it "includes generic module with self (check return type, success)" do + assert_type(" + module Foo(T) + def foo : T + Bar(Int32).new + end end class Bar(U) @@ -234,7 +342,67 @@ describe "Semantic: module" do end Bar(Int32).new.foo - ") { generic_class("Bar", int32).metaclass } + ") { generic_class("Bar", int32) } + end + + it "includes generic module with self (check return subclass type, success)" do + assert_type(" + module Foo(T) + def foo : T + Baz.new + end + end + + class Bar(U) + include Foo(self) + end + + class Baz < Bar(Int32) + end + + Bar(Int32).new.foo + ") { types["Baz"] } + end + + it "includes generic module with self (check return type, error)" do + assert_error " + module Foo(T) + def foo : T + Bar(Int32).new + end + end + + class Bar(U) + include Foo(self) + end + + class Baz < Bar(Int32) + end + + Baz.new.foo + ", "type must be Baz, not Bar(Int32)" + end + + it "includes generic module with self (check return subclass type, error)" do + assert_error " + module Foo(T) + def foo : T + Baz2.new + end + end + + class Bar(U) + include Foo(self) + end + + class Baz1 < Bar(Int32) + end + + class Baz2 < Bar(Int32) + end + + Baz1.new.foo + ", "type must be Baz1, not Baz2" end it "includes module but can't access metaclass methods" do diff --git a/src/compiler/crystal/macros/interpreter.cr b/src/compiler/crystal/macros/interpreter.cr index feb183050564..fc39e4999383 100644 --- a/src/compiler/crystal/macros/interpreter.cr +++ b/src/compiler/crystal/macros/interpreter.cr @@ -418,6 +418,9 @@ module Crystal end TypeNode.new(matched_type) + when Self + target = @scope == @program.class_type ? @scope : @scope.instance_type + TypeNode.new(target) when ASTNode matched_type else diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index 239944fb3126..0b832baa3537 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -167,6 +167,8 @@ module Crystal node.bind_to type.value when Type node.type = check_type_in_type_args(type.remove_alias_if_simple) + when Self + node.type = check_type_in_type_args(the_self(node).remove_alias_if_simple) when ASTNode type.accept self unless type.type? node.syntax_replacement = type @@ -300,12 +302,7 @@ module Crystal end def visit(node : Self) - the_self = (@scope || current_type) - if the_self.is_a?(Program) - node.raise "there's no self in this scope" - end - - node.type = the_self.instance_type + node.type = the_self(node).instance_type end def visit(node : Var) @@ -3100,6 +3097,14 @@ module Crystal end end + def the_self(node) + the_self = (@scope || current_type) + if the_self.is_a?(Program) + node.raise "there's no self in this scope" + end + the_self + end + def visit(node : When | Unless | Until | MacroLiteral | OpAssign) raise "Bug: #{node.class_desc} node '#{node}' (#{node.location}) should have been eliminated in normalize" end diff --git a/src/compiler/crystal/semantic/semantic_visitor.cr b/src/compiler/crystal/semantic/semantic_visitor.cr index 22ebbf685882..179d88886bfc 100644 --- a/src/compiler/crystal/semantic/semantic_visitor.cr +++ b/src/compiler/crystal/semantic/semantic_visitor.cr @@ -214,8 +214,8 @@ abstract class Crystal::SemanticVisitor < Crystal::Visitor end end - def lookup_type(node : ASTNode, free_vars = nil) - current_type.lookup_type(node, free_vars: free_vars, allow_typeof: false) + def lookup_type(node : ASTNode, free_vars = nil, lazy_self = false) + current_type.lookup_type(node, free_vars: free_vars, allow_typeof: false, lazy_self: lazy_self) end def check_outside_exp(node, op) diff --git a/src/compiler/crystal/semantic/top_level_visitor.cr b/src/compiler/crystal/semantic/top_level_visitor.cr index bda70cf17b4c..0a0e986e9be9 100644 --- a/src/compiler/crystal/semantic/top_level_visitor.cr +++ b/src/compiler/crystal/semantic/top_level_visitor.cr @@ -811,7 +811,7 @@ class Crystal::TopLevelVisitor < Crystal::SemanticVisitor def include_in(current_type, node, kind) node_name = node.name - type = lookup_type(node_name) + type = lookup_type(node_name, lazy_self: true) case type when GenericModuleType node.raise "wrong number of type vars for #{type} (given 0, expected #{type.type_vars.size})" diff --git a/src/compiler/crystal/semantic/type_lookup.cr b/src/compiler/crystal/semantic/type_lookup.cr index ed6c23025cf4..26c2d61a54cd 100644 --- a/src/compiler/crystal/semantic/type_lookup.cr +++ b/src/compiler/crystal/semantic/type_lookup.cr @@ -38,28 +38,28 @@ class Crystal::Type # ``` # # If `self` is `Foo` and `Bar(Baz)` is given, the result will be `Foo::Bar(Baz)`. - def lookup_type(node : ASTNode, self_type = self.instance_type, allow_typeof = true, free_vars : Hash(String, TypeVar)? = nil) : Type - TypeLookup.new(self, self_type, true, allow_typeof, free_vars).lookup(node).not_nil! + def lookup_type(node : ASTNode, self_type = self.instance_type, allow_typeof = true, lazy_self = false, free_vars : Hash(String, TypeVar)? = nil) : Type + TypeLookup.new(self, self_type, true, allow_typeof, lazy_self, free_vars).lookup(node).not_nil! end # Similar to `lookup_type`, but returns `nil` if a type can't be found. - def lookup_type?(node : ASTNode, self_type = self.instance_type, allow_typeof = true, free_vars : Hash(String, TypeVar)? = nil) : Type? - TypeLookup.new(self, self_type, false, allow_typeof, free_vars).lookup(node) + def lookup_type?(node : ASTNode, self_type = self.instance_type, allow_typeof = true, lazy_self = false, free_vars : Hash(String, TypeVar)? = nil) : Type? + TypeLookup.new(self, self_type, false, allow_typeof, lazy_self, free_vars).lookup(node) end # Similar to `lookup_type`, but the result might also be an ASTNode, for example when # looking `N` relative to a StaticArray. def lookup_type_var(node : Path, free_vars : Hash(String, TypeVar)? = nil) : Type | ASTNode - TypeLookup.new(self, self.instance_type, true, false, free_vars).lookup_type_var(node).not_nil! + TypeLookup.new(self, self.instance_type, true, false, false, free_vars).lookup_type_var(node).not_nil! end # Similar to `lookup_type_var`, but might return `nil`. def lookup_type_var?(node : Path, free_vars : Hash(String, TypeVar)? = nil, raise = false) : Type | ASTNode | Nil - TypeLookup.new(self, self.instance_type, raise, false, free_vars).lookup_type_var?(node) + TypeLookup.new(self, self.instance_type, raise, false, false, free_vars).lookup_type_var?(node) end private struct TypeLookup - def initialize(@root : Type, @self_type : Type, @raise : Bool, @allow_typeof : Bool, @free_vars : Hash(String, TypeVar)? = nil) + def initialize(@root : Type, @self_type : Type, @raise : Bool, @allow_typeof : Bool, @lazy_self : Bool, @free_vars : Hash(String, TypeVar)? = nil) @in_generic_args = 0 # If we are looking types inside a non-instantiated generic type, @@ -84,6 +84,8 @@ class Crystal::Type node.raise "#{type_var} is not a type, it's a constant" when Type return type_var + when Self + return lookup(type_var) end if @raise @@ -211,8 +213,14 @@ class Crystal::Type type_vars = Array(TypeVar).new(node.type_vars.size + 1) node.type_vars.each do |type_var| case type_var + when Self + if @lazy_self + type_vars << type_var + next + end when NumberLiteral type_vars << type_var + next when Splat type = in_generic_args { lookup(type_var.exp) } return if !@raise && !type @@ -230,37 +238,38 @@ class Crystal::Type type_var.raise "can only splat tuple type, not #{splat_type}" end - else - # Check the case of T resolving to a number - if type_var.is_a?(Path) && type_var.names.size == 1 - type = @root.lookup_path(type_var) - case type - when Const - interpreter = MathInterpreter.new(@root) - begin - num = interpreter.interpret(type.value) - type_vars << NumberLiteral.new(num) - rescue ex : Crystal::Exception - type_var.raise "expanding constant value for a number value", inner: ex - end - next - when ASTNode - type_vars << type - next + next + end + + # Check the case of T resolving to a number + if type_var.is_a?(Path) && type_var.names.size == 1 + type = @root.lookup_path(type_var) + case type + when Const + interpreter = MathInterpreter.new(@root) + begin + num = interpreter.interpret(type.value) + type_vars << NumberLiteral.new(num) + rescue ex : Crystal::Exception + type_var.raise "expanding constant value for a number value", inner: ex end + next + # when ASTNode + # type_vars << type + # next end + end - type = in_generic_args { lookup(type_var) } - return if !@raise && !type - type = type.not_nil! - - case instance_type - when GenericUnionType, PointerType, StaticArrayType, TupleType, ProcType - check_type_allowed_in_generics(type_var, type, "can't use #{type} as a generic type argument") - end + type = in_generic_args { lookup(type_var) } + return if !@raise && !type + type = type.not_nil! - type_vars << type.virtual_type + case instance_type + when GenericUnionType, PointerType, StaticArrayType, TupleType, ProcType + check_type_allowed_in_generics(type_var, type, "can't use #{type} as a generic type argument") end + + type_vars << type.virtual_type end begin