Skip to content

Commit

Permalink
Merge pull request lcompilers#296 from certik/asr_update
Browse files Browse the repository at this point in the history
ASR updates from LFortran
  • Loading branch information
certik committed Mar 29, 2022
2 parents 014463b + 0be6e87 commit 64e0c33
Show file tree
Hide file tree
Showing 11 changed files with 309 additions and 116 deletions.
38 changes: 22 additions & 16 deletions grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,8 +683,9 @@ def visitModule(self, mod):
self.emit("")
self.emit("public:")
self.emit(" bool success;")
self.emit(" bool allow_procedure_calls;")
self.emit("")
self.emit(" ExprStmtDuplicator(Allocator& al_) : al(al_), success(false) {}")
self.emit(" ExprStmtDuplicator(Allocator& al_) : al(al_), success(false), allow_procedure_calls(true) {}")
self.emit("")
self.duplicate_stmt.append((" ASR::stmt_t* duplicate_stmt(ASR::stmt_t* x) {", 0))
self.duplicate_stmt.append((" if( !x ) {", 1))
Expand Down Expand Up @@ -743,19 +744,6 @@ def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
if name == "FunctionCall" or name == "SubroutineCall":
if self.is_stmt:
self.duplicate_stmt.append((" case ASR::stmtType::%s: {" % name, 2))
self.duplicate_stmt.append((" success = false;", 3))
self.duplicate_stmt.append((" return nullptr;", 3))
self.duplicate_stmt.append((" }", 2))
elif self.is_expr:
self.duplicate_expr.append((" case ASR::exprType::%s: {" % name, 2))
self.duplicate_expr.append((" success = false;", 3))
self.duplicate_expr.append((" return nullptr;", 3))
self.duplicate_expr.append((" }", 2))
return None

self.emit("")
self.emit("ASR::asr_t* duplicate_%s(%s_t* x) {" % (name, name), 1)
self.used = False
Expand All @@ -771,26 +759,44 @@ def make_visitor(self, name, fields):
self.emit("return make_%s_t(al, x->base.base.loc, %s);" %(name, node_arg_str), 2)
if self.is_stmt:
self.duplicate_stmt.append((" case ASR::stmtType::%s: {" % name, 2))
if name == "SubroutineCall":
self.duplicate_stmt.append((" if( !allow_procedure_calls ) {", 3))
self.duplicate_stmt.append((" success = false;", 4))
self.duplicate_stmt.append((" return nullptr;", 4))
self.duplicate_stmt.append((" }", 3))
self.duplicate_stmt.append((" return down_cast<ASR::stmt_t>(duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name, name), 3))
self.duplicate_stmt.append((" }", 2))
elif self.is_expr:
self.duplicate_expr.append((" case ASR::exprType::%s: {" % name, 2))
if name == "FunctionCall":
self.duplicate_expr.append((" if( !allow_procedure_calls ) {", 3))
self.duplicate_expr.append((" success = false;", 4))
self.duplicate_expr.append((" return nullptr;", 4))
self.duplicate_expr.append((" }", 3))
self.duplicate_expr.append((" return down_cast<ASR::expr_t>(duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name, name), 3))
self.duplicate_expr.append((" }", 2))
self.emit("}", 1)
self.emit("")

def visitField(self, field):
arguments = None
if field.type == "expr" or field.type == "stmt" or field.type == "symbol":
if field.type == "expr" or field.type == "stmt" or field.type == "symbol" or field.type == "call_arg":
level = 2
if field.seq:
self.used = True
self.emit("Vec<%s_t*> m_%s;" % (field.type, field.name), level)
pointer_char = ''
if field.type != "call_arg":
pointer_char = '*'
self.emit("Vec<%s_t%s> m_%s;" % (field.type, pointer_char, field.name), level)
self.emit("m_%s.reserve(al, x->n_%s);" % (field.name, field.name), level)
self.emit("for (size_t i = 0; i < x->n_%s; i++) {" % field.name, level)
if field.type == "symbol":
self.emit(" m_%s.push_back(al, x->m_%s[i]);" % (field.name, field.name), level)
elif field.type == "call_arg":
self.emit(" ASR::call_arg_t call_arg_copy;", level)
self.emit(" call_arg_copy.loc = x->m_%s[i].loc;"%(field.name), level)
self.emit(" call_arg_copy.m_value = duplicate_expr(x->m_%s[i].m_value);"%(field.name), level)
self.emit(" m_%s.push_back(al, call_arg_copy);"%(field.name), level)
else:
self.emit(" m_%s.push_back(al, duplicate_%s(x->m_%s[i]));" % (field.name, field.type, field.name), level)
self.emit("}", level)
Expand Down
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(SRC
pass/fma.cpp
pass/sign_from_value.cpp
pass/inline_function_calls.cpp
pass/loop_unroll.cpp

asr_verify.cpp
asr_utils.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
"Var_t::m_v cannot be nullptr");
require(is_a<Variable_t>(*x.m_v) || is_a<ExternalSymbol_t>(*x.m_v)
|| is_a<Function_t>(*x.m_v) || is_a<Subroutine_t>(*x.m_v),
"Var_t::m_v does not point to a Variable_t, ExternalSymbol_t," \
"Var_t::m_v " + std::string(ASRUtils::symbol_name(x.m_v)) + " does not point to a Variable_t, ExternalSymbol_t," \
"Function_t, or Subroutine_t");
require(symtab_in_scope(current_symtab, x.m_v),
"Var::m_v `" + std::string(ASRUtils::symbol_name(x.m_v)) + "` cannot point outside of its symbol table");
Expand Down
6 changes: 6 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include <libasr/pass/flip_sign.h>
#include <libasr/pass/div_to_mul.h>
#include <libasr/pass/fma.h>
#include <libasr/pass/loop_unroll.h>
#include <libasr/pass/sign_from_value.h>
#include <libasr/pass/class_constructor.h>
#include <libasr/pass/unused_functions.h>
Expand Down Expand Up @@ -4160,6 +4161,11 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
pass_replace_arr_slice(al, asr, rl_path);
pass_replace_array_op(al, asr, rl_path);
pass_replace_print_arr(al, asr, rl_path);

if( fast ) {
pass_loop_unroll(al, asr, rl_path);
}

pass_replace_do_loops(al, asr);
pass_replace_forall(al, asr);
pass_replace_select_case(al, asr);
Expand Down
82 changes: 2 additions & 80 deletions src/libasr/pass/do_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <libasr/asr_verify.h>
#include <libasr/pass/do_loops.h>
#include <libasr/pass/stmt_walk_visitor.h>
#include <libasr/pass/pass_utils.h>

namespace LFortran {

Expand Down Expand Up @@ -32,93 +33,14 @@ This ASR pass replaces do loops with while loops. The function
The comparison is >= for c<0.
*/
Vec<ASR::stmt_t*> replace_doloop(Allocator &al, const ASR::DoLoop_t &loop) {
Location loc = loop.base.base.loc;
ASR::expr_t *a=loop.m_head.m_start;
ASR::expr_t *b=loop.m_head.m_end;
ASR::expr_t *c=loop.m_head.m_increment;
ASR::expr_t *cond = nullptr;
ASR::stmt_t *inc_stmt = nullptr;
ASR::stmt_t *stmt1 = nullptr;
if( !a && !b && !c ) {
ASR::ttype_t *cond_type = LFortran::ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0));
cond = LFortran::ASRUtils::EXPR(ASR::make_ConstantLogical_t(al, loc, true, cond_type));
} else {
LFORTRAN_ASSERT(a);
LFORTRAN_ASSERT(b);
if (!c) {
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
c = LFortran::ASRUtils::EXPR(ASR::make_ConstantInteger_t(al, loc, 1, type));
}
LFORTRAN_ASSERT(c);
int increment;
if (c->type == ASR::exprType::ConstantInteger) {
increment = down_cast<ASR::ConstantInteger_t>(c)->m_n;
} else if (c->type == ASR::exprType::UnaryOp) {
ASR::UnaryOp_t *u = down_cast<ASR::UnaryOp_t>(c);
LFORTRAN_ASSERT(u->m_op == ASR::unaryopType::USub);
LFORTRAN_ASSERT(u->m_operand->type == ASR::exprType::ConstantInteger);
increment = - down_cast<ASR::ConstantInteger_t>(u->m_operand)->m_n;
} else {
throw LFortranException("Do loop increment type not supported");
}
ASR::cmpopType cmp_op;
if (increment > 0) {
cmp_op = ASR::cmpopType::LtE;
} else {
cmp_op = ASR::cmpopType::GtE;
}
ASR::expr_t *target = loop.m_head.m_v;
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
stmt1 = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr, nullptr)),
nullptr));

cond = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
cmp_op, b, type, nullptr, nullptr));

inc_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
nullptr));
}
Vec<ASR::stmt_t*> body;
body.reserve(al, loop.n_body + (inc_stmt != nullptr));
if( inc_stmt ) {
body.push_back(al, inc_stmt);
}
for (size_t i=0; i<loop.n_body; i++) {
body.push_back(al, loop.m_body[i]);
}
ASR::stmt_t *stmt2 = LFortran::ASRUtils::STMT(ASR::make_WhileLoop_t(al, loc, cond,
body.p, body.size()));
Vec<ASR::stmt_t*> result;
result.reserve(al, 2);
if( stmt1 ) {
result.push_back(al, stmt1);
}
result.push_back(al, stmt2);

/*
std::cout << "Input:" << std::endl;
std::cout << pickle((ASR::asr_t&)loop);
std::cout << "Output:" << std::endl;
std::cout << pickle((ASR::asr_t&)*stmt1);
std::cout << pickle((ASR::asr_t&)*stmt2);
std::cout << "--------------" << std::endl;
*/

return result;
}

class DoLoopVisitor : public ASR::StatementWalkVisitor<DoLoopVisitor>
{
public:
DoLoopVisitor(Allocator &al) : StatementWalkVisitor(al) {
}

void visit_DoLoop(const ASR::DoLoop_t &x) {
pass_result = replace_doloop(al, x);
pass_result = PassUtils::replace_doloop(al, x);
}
};

Expand Down
88 changes: 70 additions & 18 deletions src/libasr/pass/inline_function_calls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,29 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa

std::string current_routine;

bool inline_external_symbol_calls;


ASR::ExprStmtDuplicator node_duplicator;

public:

bool function_inlined;

InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_) : PassVisitor(al_, nullptr),
InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_, bool inline_external_symbol_calls_)
: PassVisitor(al_, nullptr),
rl_path(rl_path_), function_result_var(nullptr),
from_inline_function_call(false), inlining_function(false),
current_routine(""), node_duplicator(al_), function_inlined(false)
current_routine(""), inline_external_symbol_calls(inline_external_symbol_calls_),
node_duplicator(al_), function_inlined(false)
{
pass_result.reserve(al, 1);
}

void configure_node_duplicator(bool allow_procedure_calls_) {
node_duplicator.allow_procedure_calls = allow_procedure_calls_;
}

void visit_Function(const ASR::Function_t &x) {
// FIXME: this is a hack, we need to pass in a non-const `x`,
// which requires to generate a TransformVisitor.
Expand All @@ -88,6 +97,52 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
}

void visit_FunctionCall(const ASR::FunctionCall_t& x) {
// If this node is visited by any other visitor
// or it is being visited while inlining another function call
// then return. To ensure that only one function call is inlined
// at a time.
if( !from_inline_function_call || inlining_function ) {
if( !inlining_function ) {
return ;
}
// TODO: Handle type later
if( ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) ) {
ASR::ExternalSymbol_t* called_sym_ext = ASR::down_cast<ASR::ExternalSymbol_t>(x.m_name);
ASR::symbol_t* f_sym = ASRUtils::symbol_get_past_external(called_sym_ext->m_external);
ASR::Function_t* f = ASR::down_cast<ASR::Function_t>(f_sym);

// Never inline intrinsic functions
if( ASRUtils::is_intrinsic_function2(f) ) {
return ;
}

ASR::symbol_t* called_sym = x.m_name;

// TODO: Hanlde later
// ASR::symbol_t* called_sym_original = x.m_original_name;

ASR::FunctionCall_t& xx = const_cast<ASR::FunctionCall_t&>(x);
std::string called_sym_name = std::string(called_sym_ext->m_name);
std::string new_sym_name_str = current_scope->get_unique_name(called_sym_name);
char* new_sym_name = s2c(al, new_sym_name_str);
if( current_scope->scope.find(new_sym_name_str) == current_scope->scope.end() ) {
ASR::Module_t *m = ASR::down_cast2<ASR::Module_t>(f->m_symtab->parent->asr_owner);
char *modname = m->m_name;
ASR::symbol_t* new_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
al, called_sym->base.loc, current_scope, new_sym_name,
f_sym, modname, nullptr, 0,
f->m_name, ASR::accessType::Private));
current_scope->scope[new_sym_name_str] = new_sym;
}
xx.m_name = current_scope->scope[new_sym_name_str];
}

for( size_t i = 0; i < x.n_args; i++ ) {
visit_expr(*x.m_args[i].m_value);
}
return ;
}

// Clear up any local variables present in arg2value map
// due to inlining other function calls
arg2value.clear();
Expand All @@ -98,18 +153,15 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
Vec<ASR::stmt_t*> pass_result_local;
pass_result_local.reserve(al, 1);

// If this node is visited by any other visitor
// or it is being visited while inlining another function call
// then return. To ensure that only one function call is inlined
// at a time.
if( !from_inline_function_call || inlining_function ) {
return ;
}

// Avoid external symbols for now.
ASR::symbol_t* routine = x.m_name;
if( !ASR::is_a<ASR::Function_t>(*routine) ) {
return ;
if( ASR::is_a<ASR::ExternalSymbol_t>(*routine) &&
inline_external_symbol_calls) {
routine = ASRUtils::symbol_get_past_external(x.m_name);
} else {
return ;
}
}

// Avoid inlining current function call if its a recursion.
Expand Down Expand Up @@ -294,13 +346,13 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
};

void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit,
const std::string& rl_path) {
InlineFunctionCallVisitor v(al, rl_path);
v.function_inlined = true;
while( v.function_inlined ) {
v.function_inlined = false;
v.visit_TranslationUnit(unit);
}
const std::string& rl_path,
bool inline_external_symbol_calls) {
InlineFunctionCallVisitor v(al, rl_path, inline_external_symbol_calls);
v.configure_node_duplicator(false);
v.visit_TranslationUnit(unit);
v.configure_node_duplicator(true);
v.visit_TranslationUnit(unit);
LFORTRAN_ASSERT(asr_verify(unit));
}

Expand Down
4 changes: 3 additions & 1 deletion src/libasr/pass/inline_function_calls.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

namespace LFortran {

void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit, const std::string& rl_path);
void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit,
const std::string& rl_path,
bool inline_external_symbol_calls=true);

} // namespace LFortran

Expand Down
Loading

0 comments on commit 64e0c33

Please sign in to comment.