diff --git a/spec/compiler/normalize/case_spec.cr b/spec/compiler/normalize/case_spec.cr index d47820b639ce..cee825e818b5 100644 --- a/spec/compiler/normalize/case_spec.cr +++ b/spec/compiler/normalize/case_spec.cr @@ -53,6 +53,10 @@ describe "Normalize: case" do assert_expand "case x = 1; when 2; 3; end", "x = 1\nif 2 === x\n 3\nend" end + it "normalizes case with assignment wrapped by paren" do + assert_expand "case (x = 1); when 2; 3; end", "x = 1\nif 2 === x\n 3\nend" + end + it "normalizes case without value" do assert_expand "case when 2; 3; when 4; 5; end", "if 2\n 3\nelse\n if 4\n 5\n end\nend" end diff --git a/src/compiler/crystal/semantic/cleanup_transformer.cr b/src/compiler/crystal/semantic/cleanup_transformer.cr index 74ffcde7d6b0..cd6b402fe3ec 100644 --- a/src/compiler/crystal/semantic/cleanup_transformer.cr +++ b/src/compiler/crystal/semantic/cleanup_transformer.cr @@ -128,6 +128,10 @@ module Crystal end def transform(node : Expressions) + if exp = node.single_expression? + return exp.transform(self) + end + exps = [] of ASTNode node.expressions.each_with_index do |exp, i| @@ -144,6 +148,7 @@ module Crystal end def flatten_collect(exp, exps) + exp = exp.single_expression if exp.is_a?(Expressions) exp.expressions.each do |subexp| return true if flatten_collect(subexp, exps) diff --git a/src/compiler/crystal/semantic/literal_expander.cr b/src/compiler/crystal/semantic/literal_expander.cr index 71c78c87b757..f06f49e91a8d 100644 --- a/src/compiler/crystal/semantic/literal_expander.cr +++ b/src/compiler/crystal/semantic/literal_expander.cr @@ -208,11 +208,7 @@ module Crystal # temp # end def expand(node : And) - left = node.left - - if left.is_a?(Expressions) && left.expressions.size == 1 - left = left.expressions.first - end + left = node.left.single_expression new_node = if left.is_a?(Var) || (left.is_a?(IsA) && left.obj.is_a?(Var)) If.new(left, node.right, left.clone) @@ -245,11 +241,7 @@ module Crystal # b # end def expand(node : Or) - left = node.left - - if left.is_a?(Expressions) && left.expressions.size == 1 - left = left.expressions.first - end + left = node.left.single_expression new_node = if left.is_a?(Var) || (left.is_a?(IsA) && left.obj.is_a?(Var)) If.new(left, left.clone, node.right) @@ -390,7 +382,7 @@ module Crystal assigns = [] of ASTNode temp_vars = conds.map do |cond| - case cond + case cond = cond.single_expression when Var, InstanceVar temp_var = cond when Assign diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index fdfe9ffe1124..a2dc89f5220d 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -1742,7 +1742,7 @@ module Crystal target = exp.target return target if target.is_a?(Var) when Expressions - return unless exp = single_expression(exp) + return unless exp = exp.single_expression? return get_expression_var(exp) end nil @@ -1850,13 +1850,13 @@ module Crystal # block is when the condition is a Var (in the else it must be # nil), IsA (in the else it's not that type), RespondsTo # (in the else it doesn't respond to that message) or Not. - case cond = single_expression(node.cond) || node.cond + case cond = node.cond.single_expression when Var, IsA, RespondsTo, Not filter_vars cond_type_filters, &.not when Or # Try to apply boolean logic: `!(a || b)` is `!a && !b` - cond_left = single_expression(cond.left) || cond.left - cond_right = single_expression(cond.right) || cond.right + cond_left = cond.left.single_expression + cond_right = cond.right.single_expression # We can't deduce anything for sub && or || expressions or_left_type_filters = nil if cond_left.is_a?(And) || cond_left.is_a?(Or) @@ -2026,7 +2026,7 @@ module Crystal node.body.accept self end - cond = single_expression(node.cond) || node.cond + cond = node.cond.single_expression endless_while = cond.true_literal? merge_while_vars cond, endless_while, before_cond_vars_copy, before_cond_vars, after_cond_vars, @vars, node.break_vars @@ -2158,7 +2158,7 @@ module Crystal when Call return get_while_cond_assign_target(node.obj) when Expressions - return unless node = single_expression(node) + return unless node = node.single_expression? return get_while_cond_assign_target(node) end @@ -2192,16 +2192,6 @@ module Crystal end end - def single_expression(node) - result = nil - - while node.is_a?(Expressions) && node.expressions.size == 1 - result = node = node[0] - end - - result - end - def end_visit(node : Break) if last_block_kind == :ensure node.raise "can't use break inside ensure" diff --git a/src/compiler/crystal/syntax/ast.cr b/src/compiler/crystal/syntax/ast.cr index a4bbd4e677af..e56568ecf32a 100644 --- a/src/compiler/crystal/syntax/ast.cr +++ b/src/compiler/crystal/syntax/ast.cr @@ -89,6 +89,18 @@ module Crystal def pretty_print(pp) pp.text to_s end + + # It yields itself for any node, but `Expressions` yields first node + # if it holds only a node. + def single_expression + single_expression? || self + end + + # It yields `nil` always. + # (It is overrided by `Expressions` to implement `#single_expression`.) + def single_expression? + nil + end end class Nop < ASTNode @@ -146,6 +158,13 @@ module Crystal @end_location || @expressions.last?.try &.end_location end + # It yields first node if this holds only one node, or yields `nil`. + def single_expression? + return @expressions.first.single_expression if @expressions.size == 1 + + nil + end + def accept_children(visitor) @expressions.each &.accept visitor end