diff --git a/spec/compiler/codegen/macro_spec.cr b/spec/compiler/codegen/macro_spec.cr index 9f3e6e1211d1..f789887a321d 100644 --- a/spec/compiler/codegen/macro_spec.cr +++ b/spec/compiler/codegen/macro_spec.cr @@ -1849,4 +1849,68 @@ describe "Code gen: macro" do (Foo.new || Bar.new).foo )).to_string.should eq("Foo") end + + it "invokes macro method inside Crystal::Macros module" do + run(%( + module Crystal::Macros + macro foo(x) + x + "bar" + end + end + + macro bar(x) + {{ foo(x) }} + end + + bar("foo") + ), inject_primitives: false).to_string.should eq("foobar") + end + + it "invokes macro method of ASTNode" do + run(%( + class Crystal::Macros::StringLiteral + macro plus_bar + self + "bar" + end + end + + macro bar(x) + {{ x.plus_bar }} + end + + bar("foo") + ), inject_primitives: false).to_string.should eq("foobar") + end + + it "invokes macro method of any type" do + run(%( + module Foo + macro foo(x) + x + "bar" + end + end + + macro bar(x) + {{ Foo.foo(x) }} + end + + bar("foo") + ), inject_primitives: false).to_string.should eq("foobar") + end + + it "invokes macro method of any type, with return" do + run(%( + module Foo + macro foo(x) + return x + "bar" + end + end + + macro bar(x) + {{ Foo.foo(x) }} + end + + bar("foo") + ), inject_primitives: false).to_string.should eq("foobar") + end end diff --git a/src/compiler/crystal/macros/interpreter.cr b/src/compiler/crystal/macros/interpreter.cr index db5bcceff172..6d45c18ae050 100644 --- a/src/compiler/crystal/macros/interpreter.cr +++ b/src/compiler/crystal/macros/interpreter.cr @@ -74,6 +74,9 @@ module Crystal record MacroVarKey, name : String, exps : Array(ASTNode)? + getter program + property? macro_method_mode = false + def initialize(@program : Program, @scope : Type, @path_lookup : Type, @location : Location?, @vars = {} of String => ASTNode, @block : Block? = nil, @def : Def? = nil, @@ -581,7 +584,47 @@ module Crystal false end + def visit(node : Return) + if macro_method_mode? + exp = node.exp + if exp + exp.accept self + else + @last = NilLiteral.new + end + else + cant_execute(node) + end + end + + def visit(node : While) + while true + node.cond.accept self + break if !@last.truthy? + node.body.accept self + end + false + end + + def visit(node : Until) + while true + node.cond.accept self + break if @last.truthy? + node.body.accept self + end + false + end + + def visit(node : OpAssign) + @program.normalize(node).accept self + false + end + def visit(node : ASTNode) + cant_execute(node) + end + + def cant_execute(node) node.raise "can't execute #{node.class_desc} in a macro" end diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 4bf8f5e91560..c8f6d38e65cf 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -42,6 +42,36 @@ module Crystal node.raise("undefined macro method: '#{node.name}'") end + def interpret_call_inside_type(owner, name, args, named_args, block, self self_value = nil) + if named_args.is_a?(Hash) + named_args = named_args.map do |name, value| + NamedArgument.new(name, value) + end + end + + matching_macro = owner.lookup_macro(name, args, named_args) + return unless matching_macro.is_a?(Macro) + + interpreter = MacroInterpreter.new( + @program, + @scope, + @program.crystal, + matching_macro, + Call.new( + obj: nil, + name: name, + args: args, + named_args: named_args, + block: block), + nil, + true) + interpreter.macro_method_mode = true + interpreter.define_var("self", self_value) if self_value + + interpreter.accept(matching_macro.parsed_body) + @last = interpreter.last + end + def interpret_top_level_call?(node) # Please order method names in lexicographical order case node.name @@ -72,8 +102,27 @@ module Crystal when "run" interpret_run(node) else - nil + interpret_user_defined_top_level_call(node) + end + end + + def interpret_user_defined_top_level_call(node) + args = node.args.map do |arg| + accept(arg) + @last + end + + named_args = node.named_args.try &.map do |named_arg| + accept(named_arg.value) + NamedArgument.new(named_arg.name, @last) end + + # Top-levels macro calls are looked up inside the Crystal::Macros module, + # just like built-in top-level macro methods. + macros = @program.crystal.lookup_type?(Path.new("Macros")) + return unless macros + + interpret_call_inside_type(macros, node.name, args, named_args, node.block) end def interpret_compare_versions(node) @@ -381,6 +430,14 @@ module Crystal when "!" BoolLiteral.new(!truthy?) else + # Try to lookup a type for this node, for example Crystal::Macros::StringLiteral + owner = interpreter.program.crystal.lookup_type?(Path.new(["Macros", class_desc])) + if owner + # Then try to lookup a user-defined macro in that type + value = interpreter.interpret_call_inside_type(owner, method, args, named_args, block, self: self) + return value if value + end + raise "undefined macro method '#{class_desc}##{method}'", exception_type: Crystal::UndefinedMacroMethodError end end @@ -1424,6 +1481,45 @@ module Crystal super end end + + property(parsed_body : ASTNode) do + vars = Set(String).new + args.each do |arg| + vars << arg.name + end + + gatherer = MacroLiteralGatherer.new + body.accept gatherer + gathered_body = gatherer.to_s + + Parser.parse(gathered_body, def_vars: [vars]) + end + + class MacroLiteralGatherer < Visitor + def initialize + @io = String::Builder.new + end + + def visit(node : Expressions) + node.expressions.each do |exp| + exp.accept self + end + false + end + + def visit(node : MacroLiteral) + @io << node.value + false + end + + def visit(node) + node.raise "Can't use #{node.class} inside macro methods" + end + + def to_s + @io.to_s + end + end end class UnaryExpression @@ -1731,6 +1827,10 @@ module Crystal when "resolve?" interpret_argless_method(method, args) { self } else + # Lookup a user-defined macro method inside the type + value = interpreter.interpret_call_inside_type(type, method, args, named_args, block) + return value if value + super end end diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr index 1730958be238..8f273f405f83 100644 --- a/src/compiler/crystal/program.cr +++ b/src/compiler/crystal/program.cr @@ -205,6 +205,8 @@ module Crystal types["Proc"] = @proc = ProcType.new self, self, "Proc", value, ["T", "R"] types["Union"] = @union = GenericUnionType.new self, self, "Union", value, ["T"] types["Crystal"] = @crystal = NonGenericModuleType.new self, self, "Crystal" + crystal.types["Macros"] = macros = NonGenericModuleType.new self, crystal, "Macros" + define_crystal_macros_ast_nodes(macros) types["ARGC_UNSAFE"] = @argc = argc_unsafe = Const.new self, self, "ARGC_UNSAFE", Primitive.new("argc", int32) types["ARGV_UNSAFE"] = @argv = argv_unsafe = Const.new self, self, "ARGV_UNSAFE", Primitive.new("argv", pointer_of(pointer_of(uint8))) @@ -278,6 +280,22 @@ module Crystal crystal.types[name] = Const.new self, crystal, name, value end + private def define_crystal_macros_ast_nodes(macros) + macros.types["ASTNode"] = ast_node = NonGenericClassType.new self, macros, "ASTNode", reference + + %w(Annotation Arg ArrayLiteral Assign BinaryOp Block BoolLiteral + Call Case Cast CharLiteral ClassDef ClassVar Def Expressions + Generic Global HashLiteral If ImplicitObj InstanceVar IsA Macro + MacroId MetaVar MultiAssign NamedArgument NamedTupleLiteral NilableCast + NilLiteral Nop NumberLiteral OffsetOf Path ProcLiteral ProcNotation ProcPointer + RangeLiteral ReadInstanceVar RegexLiteral Require RespondsTo Splat + StringInterpolation StringLiteral SymbolLiteral TupleLiteral TypeDeclaration + TypeNode UnaryExpression UninitializedVar Union Var VisibilityModifier + When While).each do |name| + macros.types[name] = NonGenericClassType.new self, macros, name, ast_node + end + end + property(target_machine : LLVM::TargetMachine) { codegen_target.to_target_machine } # Returns the `Type` for `Array(type)`