Skip to content

Commit

Permalink
Improvements to ASR passes from LFortran
Browse files Browse the repository at this point in the history
  • Loading branch information
certik committed Mar 29, 2022
1 parent 162f6a4 commit 411db63
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 99 deletions.
82 changes: 2 additions & 80 deletions src/libasr/pass/do_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <libasr/asr_verify.h>
#include <libasr/pass/do_loops.h>
#include <libasr/pass/stmt_walk_visitor.h>
#include <libasr/pass/pass_utils.h>

namespace LFortran {

Expand Down Expand Up @@ -32,93 +33,14 @@ This ASR pass replaces do loops with while loops. The function
The comparison is >= for c<0.
*/
Vec<ASR::stmt_t*> replace_doloop(Allocator &al, const ASR::DoLoop_t &loop) {
Location loc = loop.base.base.loc;
ASR::expr_t *a=loop.m_head.m_start;
ASR::expr_t *b=loop.m_head.m_end;
ASR::expr_t *c=loop.m_head.m_increment;
ASR::expr_t *cond = nullptr;
ASR::stmt_t *inc_stmt = nullptr;
ASR::stmt_t *stmt1 = nullptr;
if( !a && !b && !c ) {
ASR::ttype_t *cond_type = LFortran::ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0));
cond = LFortran::ASRUtils::EXPR(ASR::make_ConstantLogical_t(al, loc, true, cond_type));
} else {
LFORTRAN_ASSERT(a);
LFORTRAN_ASSERT(b);
if (!c) {
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
c = LFortran::ASRUtils::EXPR(ASR::make_ConstantInteger_t(al, loc, 1, type));
}
LFORTRAN_ASSERT(c);
int increment;
if (c->type == ASR::exprType::ConstantInteger) {
increment = down_cast<ASR::ConstantInteger_t>(c)->m_n;
} else if (c->type == ASR::exprType::UnaryOp) {
ASR::UnaryOp_t *u = down_cast<ASR::UnaryOp_t>(c);
LFORTRAN_ASSERT(u->m_op == ASR::unaryopType::USub);
LFORTRAN_ASSERT(u->m_operand->type == ASR::exprType::ConstantInteger);
increment = - down_cast<ASR::ConstantInteger_t>(u->m_operand)->m_n;
} else {
throw LFortranException("Do loop increment type not supported");
}
ASR::cmpopType cmp_op;
if (increment > 0) {
cmp_op = ASR::cmpopType::LtE;
} else {
cmp_op = ASR::cmpopType::GtE;
}
ASR::expr_t *target = loop.m_head.m_v;
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
stmt1 = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr, nullptr)),
nullptr));

cond = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
cmp_op, b, type, nullptr, nullptr));

inc_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
nullptr));
}
Vec<ASR::stmt_t*> body;
body.reserve(al, loop.n_body + (inc_stmt != nullptr));
if( inc_stmt ) {
body.push_back(al, inc_stmt);
}
for (size_t i=0; i<loop.n_body; i++) {
body.push_back(al, loop.m_body[i]);
}
ASR::stmt_t *stmt2 = LFortran::ASRUtils::STMT(ASR::make_WhileLoop_t(al, loc, cond,
body.p, body.size()));
Vec<ASR::stmt_t*> result;
result.reserve(al, 2);
if( stmt1 ) {
result.push_back(al, stmt1);
}
result.push_back(al, stmt2);

/*
std::cout << "Input:" << std::endl;
std::cout << pickle((ASR::asr_t&)loop);
std::cout << "Output:" << std::endl;
std::cout << pickle((ASR::asr_t&)*stmt1);
std::cout << pickle((ASR::asr_t&)*stmt2);
std::cout << "--------------" << std::endl;
*/

return result;
}

class DoLoopVisitor : public ASR::StatementWalkVisitor<DoLoopVisitor>
{
public:
DoLoopVisitor(Allocator &al) : StatementWalkVisitor(al) {
}

void visit_DoLoop(const ASR::DoLoop_t &x) {
pass_result = replace_doloop(al, x);
pass_result = PassUtils::replace_doloop(al, x);
}
};

Expand Down
88 changes: 70 additions & 18 deletions src/libasr/pass/inline_function_calls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,29 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa

std::string current_routine;

bool inline_external_symbol_calls;


ASR::ExprStmtDuplicator node_duplicator;

public:

bool function_inlined;

InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_) : PassVisitor(al_, nullptr),
InlineFunctionCallVisitor(Allocator &al_, const std::string& rl_path_, bool inline_external_symbol_calls_)
: PassVisitor(al_, nullptr),
rl_path(rl_path_), function_result_var(nullptr),
from_inline_function_call(false), inlining_function(false),
current_routine(""), node_duplicator(al_), function_inlined(false)
current_routine(""), inline_external_symbol_calls(inline_external_symbol_calls_),
node_duplicator(al_), function_inlined(false)
{
pass_result.reserve(al, 1);
}

void configure_node_duplicator(bool allow_procedure_calls_) {
node_duplicator.allow_procedure_calls = allow_procedure_calls_;
}

void visit_Function(const ASR::Function_t &x) {
// FIXME: this is a hack, we need to pass in a non-const `x`,
// which requires to generate a TransformVisitor.
Expand All @@ -88,6 +97,52 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
}

void visit_FunctionCall(const ASR::FunctionCall_t& x) {
// If this node is visited by any other visitor
// or it is being visited while inlining another function call
// then return. To ensure that only one function call is inlined
// at a time.
if( !from_inline_function_call || inlining_function ) {
if( !inlining_function ) {
return ;
}
// TODO: Handle type later
if( ASR::is_a<ASR::ExternalSymbol_t>(*x.m_name) ) {
ASR::ExternalSymbol_t* called_sym_ext = ASR::down_cast<ASR::ExternalSymbol_t>(x.m_name);
ASR::symbol_t* f_sym = ASRUtils::symbol_get_past_external(called_sym_ext->m_external);
ASR::Function_t* f = ASR::down_cast<ASR::Function_t>(f_sym);

// Never inline intrinsic functions
if( ASRUtils::is_intrinsic_function2(f) ) {
return ;
}

ASR::symbol_t* called_sym = x.m_name;

// TODO: Hanlde later
// ASR::symbol_t* called_sym_original = x.m_original_name;

ASR::FunctionCall_t& xx = const_cast<ASR::FunctionCall_t&>(x);
std::string called_sym_name = std::string(called_sym_ext->m_name);
std::string new_sym_name_str = current_scope->get_unique_name(called_sym_name);
char* new_sym_name = s2c(al, new_sym_name_str);
if( current_scope->scope.find(new_sym_name_str) == current_scope->scope.end() ) {
ASR::Module_t *m = ASR::down_cast2<ASR::Module_t>(f->m_symtab->parent->asr_owner);
char *modname = m->m_name;
ASR::symbol_t* new_sym = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
al, called_sym->base.loc, current_scope, new_sym_name,
f_sym, modname, nullptr, 0,
f->m_name, ASR::accessType::Private));
current_scope->scope[new_sym_name_str] = new_sym;
}
xx.m_name = current_scope->scope[new_sym_name_str];
}

for( size_t i = 0; i < x.n_args; i++ ) {
visit_expr(*x.m_args[i].m_value);
}
return ;
}

// Clear up any local variables present in arg2value map
// due to inlining other function calls
arg2value.clear();
Expand All @@ -98,18 +153,15 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
Vec<ASR::stmt_t*> pass_result_local;
pass_result_local.reserve(al, 1);

// If this node is visited by any other visitor
// or it is being visited while inlining another function call
// then return. To ensure that only one function call is inlined
// at a time.
if( !from_inline_function_call || inlining_function ) {
return ;
}

// Avoid external symbols for now.
ASR::symbol_t* routine = x.m_name;
if( !ASR::is_a<ASR::Function_t>(*routine) ) {
return ;
if( ASR::is_a<ASR::ExternalSymbol_t>(*routine) &&
inline_external_symbol_calls) {
routine = ASRUtils::symbol_get_past_external(x.m_name);
} else {
return ;
}
}

// Avoid inlining current function call if its a recursion.
Expand Down Expand Up @@ -294,13 +346,13 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
};

void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit,
const std::string& rl_path) {
InlineFunctionCallVisitor v(al, rl_path);
v.function_inlined = true;
while( v.function_inlined ) {
v.function_inlined = false;
v.visit_TranslationUnit(unit);
}
const std::string& rl_path,
bool inline_external_symbol_calls) {
InlineFunctionCallVisitor v(al, rl_path, inline_external_symbol_calls);
v.configure_node_duplicator(false);
v.visit_TranslationUnit(unit);
v.configure_node_duplicator(true);
v.visit_TranslationUnit(unit);
LFORTRAN_ASSERT(asr_verify(unit));
}

Expand Down
4 changes: 3 additions & 1 deletion src/libasr/pass/inline_function_calls.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

namespace LFortran {

void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit, const std::string& rl_path);
void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit,
const std::string& rl_path,
bool inline_external_symbol_calls=true);

} // namespace LFortran

Expand Down
70 changes: 70 additions & 0 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,76 @@ namespace LFortran {
loc, v, args, current_scope, al, err));
}

Vec<ASR::stmt_t*> replace_doloop(Allocator &al, const ASR::DoLoop_t &loop) {
Location loc = loop.base.base.loc;
ASR::expr_t *a=loop.m_head.m_start;
ASR::expr_t *b=loop.m_head.m_end;
ASR::expr_t *c=loop.m_head.m_increment;
ASR::expr_t *cond = nullptr;
ASR::stmt_t *inc_stmt = nullptr;
ASR::stmt_t *stmt1 = nullptr;
if( !a && !b && !c ) {
ASR::ttype_t *cond_type = LFortran::ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0));
cond = LFortran::ASRUtils::EXPR(ASR::make_ConstantLogical_t(al, loc, true, cond_type));
} else {
LFORTRAN_ASSERT(a);
LFORTRAN_ASSERT(b);
if (!c) {
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
c = LFortran::ASRUtils::EXPR(ASR::make_ConstantInteger_t(al, loc, 1, type));
}
LFORTRAN_ASSERT(c);
int increment;
if (c->type == ASR::exprType::ConstantInteger) {
increment = ASR::down_cast<ASR::ConstantInteger_t>(c)->m_n;
} else if (c->type == ASR::exprType::UnaryOp) {
ASR::UnaryOp_t *u = ASR::down_cast<ASR::UnaryOp_t>(c);
LFORTRAN_ASSERT(u->m_op == ASR::unaryopType::USub);
LFORTRAN_ASSERT(u->m_operand->type == ASR::exprType::ConstantInteger);
increment = - ASR::down_cast<ASR::ConstantInteger_t>(u->m_operand)->m_n;
} else {
throw LFortranException("Do loop increment type not supported");
}
ASR::cmpopType cmp_op;
if (increment > 0) {
cmp_op = ASR::cmpopType::LtE;
} else {
cmp_op = ASR::cmpopType::GtE;
}
ASR::expr_t *target = loop.m_head.m_v;
ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
stmt1 = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr, nullptr)),
nullptr));

cond = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
cmp_op, b, type, nullptr, nullptr));

inc_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target,
LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)),
nullptr));
}
Vec<ASR::stmt_t*> body;
body.reserve(al, loop.n_body + (inc_stmt != nullptr));
if( inc_stmt ) {
body.push_back(al, inc_stmt);
}
for (size_t i=0; i<loop.n_body; i++) {
body.push_back(al, loop.m_body[i]);
}
ASR::stmt_t *stmt2 = LFortran::ASRUtils::STMT(ASR::make_WhileLoop_t(al, loc, cond,
body.p, body.size()));
Vec<ASR::stmt_t*> result;
result.reserve(al, 2);
if( stmt1 ) {
result.push_back(al, stmt1);
}
result.push_back(al, stmt2);

return result;
}

}

}
2 changes: 2 additions & 0 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ namespace LFortran {
SymbolTable*& current_scope, Location& loc,
const std::function<void (const std::string &, const Location &)> err);

Vec<ASR::stmt_t*> replace_doloop(Allocator &al, const ASR::DoLoop_t &loop);

template <class Derived>
class PassVisitor: public ASR::BaseWalkVisitor<Derived> {

Expand Down

0 comments on commit 411db63

Please sign in to comment.