Skip to content

Commit

Permalink
Modify ASR::Function_t returning an array instead of creating a n…
Browse files Browse the repository at this point in the history
…ew one (lcompilers#1745)
  • Loading branch information
Thirumalai-Shaktivel authored Apr 24, 2023
1 parent 91bc4aa commit baf30f7
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 310 deletions.
195 changes: 39 additions & 156 deletions src/libasr/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,13 +758,6 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
}

void replace_FunctionCall(ASR::FunctionCall_t* x) {
std::string x_name;
if( x->m_name->type == ASR::symbolType::ExternalSymbol ) {
x_name = ASR::down_cast<ASR::ExternalSymbol_t>(x->m_name)->m_name;
} else if( x->m_name->type == ASR::symbolType::Function ) {
x_name = ASR::down_cast<ASR::Function_t>(x->m_name)->m_name;
}

// The following checks if the name of a function actually
// points to a subroutine. If true this would mean that the
// original function returned an array and is now a subroutine.
Expand All @@ -776,9 +769,13 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
}

const Location& loc = x->base.base.loc;
ASR::symbol_t *sub = current_scope->resolve_symbol(x_name);
if (sub && ASR::is_a<ASR::Function_t>(*sub)
&& ASR::down_cast<ASR::Function_t>(sub)->m_return_var == nullptr) {
bool is_return_var_handled = false;
ASR::symbol_t *fn_name = ASRUtils::symbol_get_past_external(x->m_name);
if (ASR::is_a<ASR::Function_t>(*fn_name)) {
ASR::Function_t *fn = ASR::down_cast<ASR::Function_t>(fn_name);
is_return_var_handled = fn->m_return_var == nullptr;
}
if (is_return_var_handled) {
bool is_dimension_empty = false;
ASR::ttype_t* result_var_type = x->m_type;
ASR::dimension_t* m_dims = nullptr;
Expand All @@ -792,9 +789,17 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
if( result_type && is_dimension_empty ) {
result_var_type = result_type;
}
// TODO: Remove allocatable attribute from temporary variable

ASR::storage_typeType storage = ASR::storage_typeType::Default;
if (result_var != nullptr && ASR::is_a<ASR::Var_t>(*result_var)) {
ASR::Var_t *var = ASR::down_cast<ASR::Var_t>(result_var);
if (ASR::is_a<ASR::Variable_t>(*var->m_v)) {
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(var->m_v);
storage = v->m_storage;
}
}
ASR::expr_t* result_var_ = PassUtils::create_var(result_counter, "_func_call_res",
loc, result_var_type, al, current_scope);
loc, result_var_type, al, current_scope, storage);
result_counter += 1;
if( result_var == nullptr ) {
result_var = result_var_;
Expand All @@ -812,7 +817,7 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
result_arg.m_value = *current_expr;
s_args.push_back(al, result_arg);
ASR::stmt_t* subrout_call = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc,
sub, nullptr, s_args.p, s_args.size(), nullptr));
x->m_name, nullptr, s_args.p, s_args.size(), nullptr));
pass_result.push_back(al, subrout_call);
apply_again = true;
remove_original_statement = false;
Expand Down Expand Up @@ -972,104 +977,47 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit
pass_result.n = 0;
}

ASR::symbol_t* create_subroutine_from_function(ASR::Function_t* s) {
for( auto& s_item: s->m_symtab->get_scope() ) {
ASR::symbol_t* curr_sym = s_item.second;
if( curr_sym->type == ASR::symbolType::Variable ) {
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(curr_sym);
if( var->m_intent == ASR::intentType::Unspecified ) {
var->m_intent = ASR::intentType::In;
} else if( var->m_intent == ASR::intentType::ReturnVar ) {
var->m_intent = ASR::intentType::Out;
}
}
}
Vec<ASR::expr_t*> a_args;
a_args.reserve(al, s->n_args + 1);
for( size_t i = 0; i < s->n_args; i++ ) {
a_args.push_back(al, s->m_args[i]);
}
LCOMPILERS_ASSERT(s->m_return_var)
a_args.push_back(al, s->m_return_var);
ASR::FunctionType_t* s_func_type = ASR::down_cast<ASR::FunctionType_t>(s->m_function_signature);
ASR::asr_t* s_sub_asr = ASRUtils::make_Function_t_util(al, s->base.base.loc,
s->m_symtab, s->m_name, s->m_dependencies, s->n_dependencies,
a_args.p, a_args.size(), s->m_body, s->n_body,
nullptr, s_func_type->m_abi, s->m_access, s_func_type->m_deftype,
nullptr, false, false, false, s_func_type->m_inline, s_func_type->m_static,
s_func_type->m_type_params, s_func_type->n_type_params, s_func_type->m_restrictions,
s_func_type->n_restrictions, s_func_type->m_is_restriction, s->m_deterministic,
s->m_side_effect_free);
ASR::symbol_t* s_sub = ASR::down_cast<ASR::symbol_t>(s_sub_asr);
return s_sub;
}

// TODO: Only Program and While is processed, we need to process all calls
// to visit_stmt().
// TODO: Only TranslationUnit's and Program's symbol table is processed
// for transforming function->subroutine if they return arrays
void visit_TranslationUnit(const ASR::TranslationUnit_t &x) {
SymbolTable* current_scope_copy = current_scope;
current_scope = x.m_global_scope;
std::vector<std::pair<std::string, ASR::symbol_t*>> replace_vec;
// Transform functions returning arrays to subroutines
for (auto &item : x.m_global_scope->get_scope()) {
if (is_a<ASR::Function_t>(*item.second)) {
ASR::Function_t *s = down_cast<ASR::Function_t>(item.second);
if (s->m_return_var) {
/*
* A function which returns an array will be converted
* to a subroutine with the destination array as the last
* argument. This helps in avoiding deep copies and the
* destination memory directly gets filled inside the subroutine.
*/
if( PassUtils::is_array(s->m_return_var) ) {
ASR::symbol_t* s_sub = create_subroutine_from_function(s);
replace_vec.push_back(std::make_pair(item.first, s_sub));
}
}
PassUtils::handle_fn_return_var(al,
ASR::down_cast<ASR::Function_t>(item.second),
PassUtils::is_array);
}
}

// FIXME: this is a hack, we need to pass in a non-const `x`,
// which requires to generate a TransformVisitor.
ASR::TranslationUnit_t &xx = const_cast<ASR::TranslationUnit_t&>(x);
// Updating the symbol table so that now the name
// of the function (which returned array) now points
// to the newly created subroutine.
for( auto& item: replace_vec ) {
xx.m_global_scope->overwrite_symbol(item.first, item.second);
std::vector<std::string> build_order
= ASRUtils::determine_module_dependencies(x);
for (auto &item : build_order) {
LCOMPILERS_ASSERT(x.m_global_scope->get_symbol(item));
ASR::symbol_t *mod = x.m_global_scope->get_symbol(item);
visit_symbol(*mod);
}

// Now visit everything else
for (auto &item : x.m_global_scope->get_scope()) {
this->visit_symbol(*item.second);
if (!ASR::is_a<ASR::Module_t>(*item.second)) {
this->visit_symbol(*item.second);
}
}
current_scope = current_scope_copy;
}

void visit_Module(const ASR::Module_t &x) {
// FIXME: this is a hack, we need to pass in a non-const `x`,
// which requires to generate a TransformVisitor.
ASR::Module_t &xx = const_cast<ASR::Module_t&>(x);
SymbolTable* current_scope_copy = current_scope;
current_scope = xx.m_symtab;
current_scope = x.m_symtab;
for (auto &item : x.m_symtab->get_scope()) {
if (is_a<ASR::Function_t>(*item.second)) {
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(item.second);
if (s->m_return_var) {
/*
* A function which returns an array will be converted
* to a subroutine with the destination array as the last
* argument. This helps in avoiding deep copies and the
* destination memory directly gets filled inside the subroutine.
*/
if( PassUtils::is_array(s->m_return_var) ) {
ASR::symbol_t* s_sub = create_subroutine_from_function(s);
// Update the symtab with this function changes
xx.m_symtab->overwrite_symbol(item.first, s_sub);
}
}
PassUtils::handle_fn_return_var(al,
ASR::down_cast<ASR::Function_t>(item.second),
PassUtils::is_array);
}
}

Expand All @@ -1086,88 +1034,23 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit
ASR::Program_t& xx = const_cast<ASR::Program_t&>(x);
SymbolTable* current_scope_copy = current_scope;
current_scope = xx.m_symtab;
std::vector<std::pair<std::string, ASR::symbol_t*> > replace_vec;

for (auto &item : x.m_symtab->get_scope()) {
if (is_a<ASR::Function_t>(*item.second)) {
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(item.second);
if (s->m_return_var) {
/*
* A function which returns an array will be converted
* to a subroutine with the destination array as the last
* argument. This helps in avoiding deep copies and the
* destination memory directly gets filled inside the subroutine.
*/
if( PassUtils::is_array(s->m_return_var) ) {
ASR::symbol_t* s_sub = create_subroutine_from_function(s);
replace_vec.push_back(std::make_pair(item.first, s_sub));
bool is_arg = false;
size_t arg_index = 0;
for( size_t i = 0; i < xx.n_body; i++ ) {
ASR::stmt_t* stm = xx.m_body[i];
if( stm->type == ASR::stmtType::SubroutineCall ) {
ASR::SubroutineCall_t *subrout_call = ASR::down_cast<ASR::SubroutineCall_t>(stm);
for ( size_t j = 0; j < subrout_call->n_args; j++ ) {
ASR::expr_t* arg_value = subrout_call->m_args[j].m_value;
if( arg_value->type == ASR::exprType::Var ) {
ASR::Var_t* var = ASR::down_cast<ASR::Var_t>(arg_value);
ASR::symbol_t* sym = var->m_v;
if ( sym->type == ASR::symbolType::Function ) {
ASR::Function_t* subrout = ASR::down_cast<ASR::Function_t>(sym);
std::string subrout_name = std::string(subrout->m_name);
if ( subrout_name == item.first ) {
is_arg = true;
arg_index = j;
ASR::call_arg_t new_call_arg;
new_call_arg.loc = subrout_call->m_args[j].loc;
new_call_arg.m_value = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, var->base.base.loc, s_sub));
subrout_call->m_args[j] = new_call_arg;
}
}

}
}
if ( is_arg ) {
ASR::symbol_t* subrout = subrout_call->m_name;
if ( subrout->type == ASR::symbolType::Function ) {
ASR::Function_t* subrout_func = ASR::down_cast<ASR::Function_t>(subrout);
std::string subrout_func_name = std::string(subrout_func->m_name);
ASR::expr_t* arg = subrout_func->m_args[arg_index];
if( arg->type == ASR::exprType::Var ) {
ASR::Var_t* var = ASR::down_cast<ASR::Var_t>(arg);
ASR::symbol_t* sym = var->m_v;
if ( sym->type == ASR::symbolType::Function ) {
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(sym);
ASR::symbol_t* s_func = create_subroutine_from_function(ASR::down_cast<ASR::Function_t>(sym));
subrout_func->m_symtab->overwrite_symbol(func->m_name, s_func);
subrout_func->m_args[arg_index] = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, var->base.base.loc, s_func));
}
}

}
}
}
}
}
}
PassUtils::handle_fn_return_var(al,
ASR::down_cast<ASR::Function_t>(item.second),
PassUtils::is_array);
}
}

// Updating the symbol table so that now the name
// of the function (which returned array) now points
// to the newly created subroutine.
for( auto& item: replace_vec ) {
current_scope->overwrite_symbol(item.first, item.second);
}

for (auto &item : x.m_symtab->get_scope()) {
if (is_a<ASR::AssociateBlock_t>(*item.second)) {
ASR::AssociateBlock_t *s = ASR::down_cast<ASR::AssociateBlock_t>(item.second);
visit_AssociateBlock(*s);
}
if (is_a<ASR::Function_t>(*item.second)) {
ASR::Function_t *s = ASR::down_cast<ASR::Function_t>(item.second);
visit_Function(*s);
visit_Function(*ASR::down_cast<ASR::Function_t>(
item.second));
}
}

Expand Down
52 changes: 52 additions & 0 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,58 @@ namespace LCompilers {
}
}

static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x,
bool (*is_array_or_struct)(ASR::expr_t*)) {
if (x->m_return_var) {
/*
* The `return_var` of the function, which is either an array or
* struct, is set to `null`, and the destination variable will be
* passed as the last argument to the existing function. This helps
* in avoiding deep copies and the destination memory directly gets
* filled inside the function.
*/
if( is_array_or_struct(x->m_return_var)) {
for( auto& s_item: x->m_symtab->get_scope() ) {
ASR::symbol_t* curr_sym = s_item.second;
if( curr_sym->type == ASR::symbolType::Variable ) {
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(curr_sym);
if( var->m_intent == ASR::intentType::Unspecified ) {
var->m_intent = ASR::intentType::In;
} else if( var->m_intent == ASR::intentType::ReturnVar ) {
var->m_intent = ASR::intentType::Out;
}
}
}
Vec<ASR::expr_t*> a_args;
a_args.reserve(al, x->n_args + 1);
for( size_t i = 0; i < x->n_args; i++ ) {
a_args.push_back(al, x->m_args[i]);
}
LCOMPILERS_ASSERT(x->m_return_var)
a_args.push_back(al, x->m_return_var);
x->m_args = a_args.p;
x->n_args = a_args.n;
x->m_return_var = nullptr;
ASR::FunctionType_t* s_func_type = ASR::down_cast<ASR::FunctionType_t>(
x->m_function_signature);
Vec<ASR::ttype_t*> arg_types;
arg_types.reserve(al, a_args.n);
for(auto &e: a_args) {
arg_types.push_back(al, ASRUtils::expr_type(e));
}
s_func_type->m_arg_types = arg_types.p;
s_func_type->n_arg_types = arg_types.n;
s_func_type->m_return_var_type = nullptr;
}
}
for (auto &item : x->m_symtab->get_scope()) {
if (ASR::is_a<ASR::Function_t>(*item.second)) {
handle_fn_return_var(al, ASR::down_cast<ASR::Function_t>(
item.second), is_array_or_struct);
}
}
}

} // namespace PassUtils

} // namespace LCompilers
Expand Down
Loading

0 comments on commit baf30f7

Please sign in to comment.