Skip to content

Commit

Permalink
Support visit_Module in pass_array_by_data.cpp (lcompilers#1564)
Browse files Browse the repository at this point in the history
Co-authored-by: Shaikh Ubaid <shaikhubaid769@gmail.com>
  • Loading branch information
czgdp1807 and Shaikh-Ubaid authored Mar 5, 2023
1 parent cc02ba9 commit b1dd300
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 18 deletions.
12 changes: 10 additions & 2 deletions src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,14 +975,16 @@ def visitField(self, field):
field.type == "array_index" or
field.type == "alloc_arg" or
field.type == "case_stmt" or
field.type == "ttype"):
field.type == "ttype" or
field.type == "dimension"):
level = 2
if field.seq:
self.used = True
pointer_char = ''
if (field.type != "call_arg" and
field.type != "array_index" and
field.type != "alloc_arg"):
field.type != "alloc_arg" and
field.type != "dimension"):
pointer_char = '*'
self.emit("Vec<%s_t%s> m_%s;" % (field.type, pointer_char, field.name), level)
self.emit("m_%s.reserve(al, x->n_%s);" % (field.name, field.name), level)
Expand Down Expand Up @@ -1017,6 +1019,12 @@ def visitField(self, field):
self.emit(" array_index_copy.m_right = duplicate_expr(x->m_%s[i].m_right);"%(field.name), level)
self.emit(" array_index_copy.m_step = duplicate_expr(x->m_%s[i].m_step);"%(field.name), level)
self.emit(" m_%s.push_back(al, array_index_copy);"%(field.name), level)
elif field.type == "dimension":
self.emit(" ASR::dimension_t dim_copy;", level)
self.emit(" dim_copy.loc = x->m_%s[i].loc;"%(field.name), level)
self.emit(" dim_copy.m_start = self().duplicate_expr(x->m_%s[i].m_start);"%(field.name), level)
self.emit(" dim_copy.m_length = self().duplicate_expr(x->m_%s[i].m_length);"%(field.name), level)
self.emit(" m_%s.push_back(al, dim_copy);" % (field.name), level)
else:
self.emit(" m_%s.push_back(al, self().duplicate_%s(x->m_%s[i]));" % (field.name, field.type, field.name), level)
self.emit("}", level)
Expand Down
16 changes: 16 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,17 @@ inline int extract_len(ASR::expr_t* len_expr, const Location& loc) {
return a_len;
}

inline bool is_parent(SymbolTable* a, SymbolTable* b) {
SymbolTable* current_parent = b->parent;
while( current_parent ) {
if( current_parent == a ) {
return true;
}
current_parent = current_parent->parent;
}
return false;
}

inline bool is_parent(ASR::StructType_t* a, ASR::StructType_t* b) {
ASR::symbol_t* current_parent = b->m_parent;
while( current_parent ) {
Expand Down Expand Up @@ -2747,6 +2758,11 @@ static inline ASR::expr_t* compute_length_from_start_end(Allocator& al, ASR::exp
}

static inline bool is_pass_array_by_data_possible(ASR::Function_t* x, std::vector<size_t>& v) {
if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindC &&
ASRUtils::get_FunctionType(x)->m_deftype == ASR::deftypeType::Interface) {
return false;
}

ASR::ttype_t* typei = nullptr;
ASR::dimension_t* dims = nullptr;
for( size_t i = 0; i < x->n_args; i++ ) {
Expand Down
139 changes: 123 additions & 16 deletions src/libasr/pass/pass_array_by_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,15 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB
}
ASR::Var_t& xx = const_cast<ASR::Var_t&>(x);
ASR::symbol_t* x_sym = xx.m_v;
std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym));
if( current_proc_scope->get_symbol(x_sym_name) != x_sym ) {
xx.m_v = current_proc_scope->get_symbol(x_sym_name);
SymbolTable* x_sym_symtab = ASRUtils::symbol_parent_symtab(x_sym);
if( x_sym_symtab->get_counter() != current_proc_scope->get_counter() &&
!ASRUtils::is_parent(x_sym_symtab, current_proc_scope) ) {
// xx.m_v points to the function/procedure present inside
// original function's symtab. Make it point to the symbol in
// new function's symtab.
std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym));
xx.m_v = current_proc_scope->resolve_symbol(x_sym_name);
LCOMPILERS_ASSERT(xx.m_v != nullptr);
}
}

Expand All @@ -74,10 +80,17 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB
"Cannot deallocate variables in expression " +
std::to_string(tmp_expr->type));
}
std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym));
if( current_proc_scope->get_symbol(x_sym_name) != x_sym ) {
ASR::symbol_t* x_sym_new = current_proc_scope->get_symbol(x_sym_name);
xx.m_a = ASRUtils::EXPR(ASR::make_Var_t(al, x_sym_new->base.loc, x_sym_new));

SymbolTable* x_sym_symtab = ASRUtils::symbol_parent_symtab(x_sym);
if( x_sym_symtab->get_counter() != current_proc_scope->get_counter() &&
!ASRUtils::is_parent(x_sym_symtab, current_proc_scope) ) {
// xx.m_a points to the function/procedure present inside
// original function's symtab. Make it point to the symbol in
// new function's symtab.
std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym));
ASR::symbol_t* x_sym_new = current_proc_scope->resolve_symbol(x_sym_name);
xx.m_a = ASRUtils::EXPR(ASR::make_Var_t(al, x_sym_new->base.loc, x_sym_new));
LCOMPILERS_ASSERT(xx.m_a != nullptr);
}
}

Expand All @@ -88,9 +101,15 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB
}
ASR::FunctionCall_t& xx = const_cast<ASR::FunctionCall_t&>(x);
ASR::symbol_t* x_sym = xx.m_name;
std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym));
if( current_proc_scope->get_symbol(x_sym_name) != x_sym ) {
xx.m_name = current_proc_scope->get_symbol(x_sym_name);
SymbolTable* x_sym_symtab = ASRUtils::symbol_parent_symtab(x_sym);
if( x_sym_symtab->get_counter() != current_proc_scope->get_counter() &&
!ASRUtils::is_parent(x_sym_symtab, current_proc_scope) ) {
// xx.m_name points to the function/procedure present inside
// original function's symtab. Make it point to the symbol in
// new function's symtab.
std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym));
xx.m_name = current_proc_scope->resolve_symbol(x_sym_name);
LCOMPILERS_ASSERT(xx.m_name != nullptr);
}
}

Expand All @@ -107,15 +126,32 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB
}
new_body.push_back(al, new_stmt);
}

node_duplicator.allow_procedure_calls = true;
SymbolTable* new_symtab = al.make_new<SymbolTable>(current_scope);
for( auto& item: x->m_symtab->get_scope() ) {
ASR::symbol_t* new_arg = nullptr;
if( ASR::is_a<ASR::Variable_t>(*item.second) ) {
ASR::Variable_t* arg = ASR::down_cast<ASR::Variable_t>(item.second);
node_duplicator.success = true;
ASR::expr_t* m_symbolic_value = node_duplicator.duplicate_expr(arg->m_symbolic_value);
if( !node_duplicator.success ) {
return nullptr;
}
node_duplicator.success = true;
ASR::expr_t* m_value = node_duplicator.duplicate_expr(arg->m_value);
if( !node_duplicator.success ) {
return nullptr;
}
node_duplicator.success = true;
ASR::ttype_t* m_type = node_duplicator.duplicate_ttype(arg->m_type);
if( !node_duplicator.success ) {
return nullptr;
}
new_arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al,
arg->base.base.loc, new_symtab, s2c(al, item.first),
nullptr, 0, arg->m_intent, arg->m_symbolic_value, arg->m_value,
arg->m_storage, arg->m_type, arg->m_abi, arg->m_access,
nullptr, 0, arg->m_intent, m_symbolic_value, m_value,
arg->m_storage, m_type, arg->m_abi, arg->m_access,
arg->m_presence, arg->m_value_attr));
} else if( ASR::is_a<ASR::ExternalSymbol_t>(*item.second) ) {
ASR::ExternalSymbol_t* arg = ASR::down_cast<ASR::ExternalSymbol_t>(item.second);
Expand Down Expand Up @@ -222,8 +258,25 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB
current_proc_scope = nullptr;
}

void visit_Program(const ASR::Program_t& x) {
ASR::Program_t& xx = const_cast<ASR::Program_t&>(x);
void visit_TranslationUnit(const ASR::TranslationUnit_t& x) {
// Visit Module first so that all functions in it are updated
for (auto &a : x.m_global_scope->get_scope()) {
if( ASR::is_a<ASR::Module_t>(*a.second) ) {
this->visit_symbol(*a.second);
}
}

// Visit all other symbols
for (auto &a : x.m_global_scope->get_scope()) {
if( !ASR::is_a<ASR::Module_t>(*a.second) ) {
this->visit_symbol(*a.second);
}
}
}

template <typename T>
void visit_SymbolContainingFunctions(const T& x) {
T& xx = const_cast<T&>(x);
current_scope = xx.m_symtab;
for( auto& item: xx.m_symtab->get_scope() ) {
if( ASR::is_a<ASR::Function_t>(*item.second) ) {
Expand All @@ -239,6 +292,19 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB
}
}
}

void visit_Program(const ASR::Program_t& x) {
visit_SymbolContainingFunctions(x);
}

void visit_Module(const ASR::Module_t& x) {
// Do not visit intrinsic modules
if( x.m_intrinsic ) {
return ;
}

visit_SymbolContainingFunctions(x);
}
};

/*
Expand Down Expand Up @@ -269,6 +335,8 @@ class ReplaceSubroutineCallsVisitor : public PassUtils::PassVisitor<ReplaceSubro

void visit_SubroutineCall(const ASR::SubroutineCall_t& x) {
ASR::symbol_t* subrout_sym = x.m_name;
bool is_external = ASR::is_a<ASR::ExternalSymbol_t>(*subrout_sym);
subrout_sym = ASRUtils::symbol_get_past_external(subrout_sym);
if( v.proc2newproc.find(subrout_sym) == v.proc2newproc.end() ) {
return ;
}
Expand All @@ -295,8 +363,25 @@ class ReplaceSubroutineCallsVisitor : public PassUtils::PassVisitor<ReplaceSubro
}
}

ASR::symbol_t* new_subrout_sym_ = new_subrout_sym;
if( is_external ) {
ASR::ExternalSymbol_t* subrout_ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(x.m_name);
// TODO: Use SymbolTable::get_unique_name to avoid potential
// clashes with user defined functions
char* new_subrout_sym_name = ASRUtils::symbol_name(new_subrout_sym);
if( current_scope->get_symbol(new_subrout_sym_name) == nullptr ) {
new_subrout_sym_ = ASR::down_cast<ASR::symbol_t>(
ASR::make_ExternalSymbol_t(al, x.m_name->base.loc, subrout_ext_sym->m_parent_symtab,
new_subrout_sym_name, new_subrout_sym, subrout_ext_sym->m_module_name,
subrout_ext_sym->m_scope_names, subrout_ext_sym->n_scope_names, new_subrout_sym_name,
subrout_ext_sym->m_access));
current_scope->add_symbol(new_subrout_sym_name, new_subrout_sym_);
} else {
new_subrout_sym_ = current_scope->get_symbol(new_subrout_sym_name);
}
}
ASR::stmt_t* new_call = ASRUtils::STMT(ASR::make_SubroutineCall_t(al,
x.base.base.loc, new_subrout_sym, new_subrout_sym,
x.base.base.loc, new_subrout_sym_, new_subrout_sym_,
new_args.p, new_args.size(), x.m_dt));
pass_result.push_back(al, new_call);
}
Expand Down Expand Up @@ -325,11 +410,15 @@ class ReplaceFunctionCalls: public ASR::BaseExprReplacer<ReplaceFunctionCalls> {

public:

SymbolTable* current_scope;

ReplaceFunctionCalls(Allocator& al_, PassArrayByDataProcedureVisitor& v_) : al(al_), v(v_)
{}

void replace_FunctionCall(ASR::FunctionCall_t* x) {
ASR::symbol_t* subrout_sym = x->m_name;
bool is_external = ASR::is_a<ASR::ExternalSymbol_t>(*subrout_sym);
subrout_sym = ASRUtils::symbol_get_past_external(subrout_sym);
if( v.proc2newproc.find(subrout_sym) == v.proc2newproc.end() ) {
return ;
}
Expand Down Expand Up @@ -357,8 +446,25 @@ class ReplaceFunctionCalls: public ASR::BaseExprReplacer<ReplaceFunctionCalls> {
}

LCOMPILERS_ASSERT(new_args.size() == ASR::down_cast<ASR::Function_t>(new_func_sym)->n_args);
ASR::symbol_t* new_func_sym_ = new_func_sym;
if( is_external ) {
ASR::ExternalSymbol_t* func_ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(x->m_name);
// TODO: Use SymbolTable::get_unique_name to avoid potential
// clashes with user defined functions
char* new_func_sym_name = ASRUtils::symbol_name(new_func_sym);
if( current_scope->get_symbol(new_func_sym_name) == nullptr ) {
new_func_sym_ = ASR::down_cast<ASR::symbol_t>(
ASR::make_ExternalSymbol_t(al, x->m_name->base.loc, func_ext_sym->m_parent_symtab,
new_func_sym_name, new_func_sym, func_ext_sym->m_module_name,
func_ext_sym->m_scope_names, func_ext_sym->n_scope_names, new_func_sym_name,
func_ext_sym->m_access));
current_scope->add_symbol(new_func_sym_name, new_func_sym_);
} else {
new_func_sym_ = current_scope->get_symbol(new_func_sym_name);
}
}
ASR::expr_t* new_call = ASRUtils::EXPR(ASR::make_FunctionCall_t(al,
x->base.base.loc, new_func_sym, new_func_sym,
x->base.base.loc, new_func_sym_, new_func_sym_,
new_args.p, new_args.size(), x->m_type, nullptr,
x->m_dt));
*current_expr = new_call;
Expand All @@ -384,6 +490,7 @@ class ReplaceFunctionCallsVisitor : public ASR::CallReplacerOnExpressionsVisitor

void call_replacer() {
replacer.current_expr = current_expr;
replacer.current_scope = current_scope;
replacer.replace_expr(*current_expr);
}

Expand Down

0 comments on commit b1dd300

Please sign in to comment.