Skip to content

Commit

Permalink
Ignore inlining of functions with return statements
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Jun 28, 2022
1 parent 6d36eb3 commit b8f9b16
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
8 changes: 7 additions & 1 deletion grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ def visitModule(self, mod):
self.emit("public:")
self.emit(" bool success;")
self.emit(" bool allow_procedure_calls;")
self.emit(" bool allow_return_stmts;")
self.emit("")
self.emit(" ExprStmtDuplicator(Allocator& al_) : al(al_), success(false), allow_procedure_calls(true) {}")
self.emit(" ExprStmtDuplicator(Allocator& al_) : al(al_), success(false), allow_procedure_calls(true), allow_return_stmts(false) {}")
self.emit("")
self.duplicate_stmt.append((" ASR::stmt_t* duplicate_stmt(ASR::stmt_t* x) {", 0))
self.duplicate_stmt.append((" if( !x ) {", 1))
Expand Down Expand Up @@ -764,6 +765,11 @@ def make_visitor(self, name, fields):
self.duplicate_stmt.append((" success = false;", 4))
self.duplicate_stmt.append((" return nullptr;", 4))
self.duplicate_stmt.append((" }", 3))
elif name == "Return":
self.duplicate_stmt.append((" if( !allow_return_stmts ) {", 3))
self.duplicate_stmt.append((" success = false;", 4))
self.duplicate_stmt.append((" return nullptr;", 4))
self.duplicate_stmt.append((" }", 3))
self.duplicate_stmt.append((" return down_cast<ASR::stmt_t>(duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name, name), 3))
self.duplicate_stmt.append((" }", 2))
elif self.is_expr:
Expand Down
3 changes: 2 additions & 1 deletion run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def main():
x86 = test.get("x86", False)
bin_ = test.get("bin", False)
pass_ = test.get("pass", None)
if pass_ and pass_ not in ["do_loops", "global_stmts", "loop_vectorise"]:
if pass_ and pass_ not in ["do_loops", "global_stmts",
"loop_vectorise", "inline_function_calls"]:
raise Exception("Unknown pass: %s" % pass_)

print(color(style.bold)+"TEST:"+color(style.reset), filename)
Expand Down
7 changes: 4 additions & 3 deletions src/libasr/pass/inline_function_calls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor<InlineFunctionCa
pass_result.reserve(al, 1);
}

void configure_node_duplicator(bool allow_procedure_calls_) {
void configure_node_duplicator(bool allow_procedure_calls_, bool allow_return_stmts_) {
node_duplicator.allow_procedure_calls = allow_procedure_calls_;
node_duplicator.allow_return_stmts = allow_return_stmts_;
}

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -404,9 +405,9 @@ void pass_inline_function_calls(Allocator &al, ASR::TranslationUnit_t &unit,
const std::string& rl_path,
bool inline_external_symbol_calls) {
InlineFunctionCallVisitor v(al, rl_path, inline_external_symbol_calls);
v.configure_node_duplicator(false);
v.configure_node_duplicator(false, false);
v.visit_TranslationUnit(unit);
v.configure_node_duplicator(true);
v.configure_node_duplicator(true, false);
v.visit_TranslationUnit(unit);
LFORTRAN_ASSERT(asr_verify(unit));
}
Expand Down
5 changes: 5 additions & 0 deletions tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ filename = "../integration_tests/vec_01.py"
asr = true
pass = "loop_vectorise"

[[test]]
filename = "../integration_tests/func_inline_01.py"
asr = true
pass = "inline_function_calls"

[[test]]
filename = "loop1.py"
ast = true
Expand Down

0 comments on commit b8f9b16

Please sign in to comment.