Skip to content

Commit

Permalink
Merge pull request lcompilers#2326 from Shaikh-Ubaid/changes_from_lfo…
Browse files Browse the repository at this point in the history
…rtran

Changes from lfortran
  • Loading branch information
certik committed Sep 16, 2023
2 parents 825575a + 12a376a commit df83b5a
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,8 @@ def visitField(self, field):
self.emit(" self().replace_expr(x->m_%s[i]);"%(field.name), level)
self.emit(" current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
self.current_expr_copy_variable_count += 1
elif field.type == "ttype":
self.emit(" self().replace_%s(x->m_%s[i]);" % (field.type, field.name), level)
self.emit("}", level)
else:
if field.type != "symbol":
Expand Down
32 changes: 23 additions & 9 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3247,10 +3247,12 @@ class ReplaceWithFunctionParamVisitor: public ASR::BaseExprReplacer<ReplaceWithF

size_t n_args;

SymbolTable* current_scope;

public:

ReplaceWithFunctionParamVisitor(Allocator& al_, ASR::expr_t** m_args_, size_t n_args_) :
al(al_), m_args(m_args_), n_args(n_args_) {}
al(al_), m_args(m_args_), n_args(n_args_), current_scope(nullptr) {}

void replace_Var(ASR::Var_t* x) {
size_t arg_idx = 0;
Expand All @@ -3268,14 +3270,26 @@ class ReplaceWithFunctionParamVisitor: public ASR::BaseExprReplacer<ReplaceWithF
if( idx_found ) {
LCOMPILERS_ASSERT(current_expr);
ASR::ttype_t* t_ = replace_args_with_FunctionParam(
ASRUtils::symbol_type(x->m_v));
ASRUtils::symbol_type(x->m_v), current_scope);
*current_expr = ASRUtils::EXPR(ASR::make_FunctionParam_t(
al, m_args[arg_idx]->base.loc, arg_idx,
t_, nullptr));
}
}

ASR::ttype_t* replace_args_with_FunctionParam(ASR::ttype_t* t) {
void replace_Struct(ASR::Struct_t *x) {
std::string derived_type_name = ASRUtils::symbol_name(x->m_derived_type);
ASR::symbol_t* derived_type_sym = current_scope->resolve_symbol(derived_type_name);
LCOMPILERS_ASSERT_MSG( derived_type_sym != nullptr,
"derived_type_sym cannot be nullptr");
if (derived_type_sym != x->m_derived_type) {
x->m_derived_type = derived_type_sym;
}
}

ASR::ttype_t* replace_args_with_FunctionParam(ASR::ttype_t* t, SymbolTable* current_scope) {
this->current_scope = current_scope;

ASRUtils::ExprStmtDuplicator duplicator(al);
duplicator.allow_procedure_calls = true;

Expand Down Expand Up @@ -3312,21 +3326,21 @@ inline ASR::asr_t* make_FunctionType_t_util(Allocator &al,
ASR::expr_t* a_return_var, ASR::abiType a_abi, ASR::deftypeType a_deftype,
char* a_bindc_name, bool a_elemental, bool a_pure, bool a_module, bool a_inline,
bool a_static,
ASR::symbol_t** a_restrictions, size_t n_restrictions, bool a_is_restriction) {
ASR::symbol_t** a_restrictions, size_t n_restrictions, bool a_is_restriction, SymbolTable* current_scope) {
Vec<ASR::ttype_t*> arg_types;
arg_types.reserve(al, n_args);
ReplaceWithFunctionParamVisitor replacer(al, a_args, n_args);
for( size_t i = 0; i < n_args; i++ ) {
// We need to substitute all direct argument variable references with
// FunctionParam.
ASR::ttype_t *t = replacer.replace_args_with_FunctionParam(
expr_type(a_args[i]));
expr_type(a_args[i]), current_scope);
arg_types.push_back(al, t);
}
ASR::ttype_t* return_var_type = nullptr;
if( a_return_var ) {
return_var_type = replacer.replace_args_with_FunctionParam(
ASRUtils::expr_type(a_return_var));
ASRUtils::expr_type(a_return_var), current_scope);
}

LCOMPILERS_ASSERT(arg_types.size() == n_args);
Expand All @@ -3338,12 +3352,12 @@ inline ASR::asr_t* make_FunctionType_t_util(Allocator &al,
}

inline ASR::asr_t* make_FunctionType_t_util(Allocator &al, const Location &a_loc,
ASR::expr_t** a_args, size_t n_args, ASR::expr_t* a_return_var, ASR::FunctionType_t* ft) {
ASR::expr_t** a_args, size_t n_args, ASR::expr_t* a_return_var, ASR::FunctionType_t* ft, SymbolTable* current_scope) {
return ASRUtils::make_FunctionType_t_util(al, a_loc, a_args, n_args, a_return_var,
ft->m_abi, ft->m_deftype, ft->m_bindc_name, ft->m_elemental,
ft->m_pure, ft->m_module, ft->m_inline, ft->m_static,
ft->m_restrictions,
ft->n_restrictions, ft->m_is_restriction);
ft->n_restrictions, ft->m_is_restriction, current_scope);
}

inline ASR::asr_t* make_Function_t_util(Allocator& al, const Location& loc,
Expand All @@ -3357,7 +3371,7 @@ inline ASR::asr_t* make_Function_t_util(Allocator& al, const Location& loc,
ASR::ttype_t* func_type = ASRUtils::TYPE(ASRUtils::make_FunctionType_t_util(
al, loc, a_args, n_args, m_return_var, m_abi, m_deftype, m_bindc_name,
m_elemental, m_pure, m_module, m_inline, m_static,
m_restrictions, n_restrictions, m_is_restriction));
m_restrictions, n_restrictions, m_is_restriction, m_symtab));
return ASR::make_Function_t(
al, loc, m_symtab, m_name, func_type, m_dependencies, n_dependencies,
a_args, n_args, m_body, n_body, m_return_var, m_access, m_deterministic,
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/pass_array_by_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor<PassArrayB

ASR::FunctionType_t* func_type = ASRUtils::get_FunctionType(*x);
x->m_function_signature = ASRUtils::TYPE(ASRUtils::make_FunctionType_t_util(
al, func_type->base.base.loc, new_args.p, new_args.size(), x->m_return_var, func_type));
al, func_type->base.base.loc, new_args.p, new_args.size(), x->m_return_var, func_type, current_scope));
x->m_args = new_args.p;
x->n_args = new_args.size();
}
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ namespace LCompilers {
for(auto &e: a_args) {
ASRUtils::ReplaceWithFunctionParamVisitor replacer(al, x->m_args, x->n_args);
arg_types.push_back(al, replacer.replace_args_with_FunctionParam(
ASRUtils::expr_type(e)));
ASRUtils::expr_type(e), x->m_symtab));
}
s_func_type->m_arg_types = arg_types.p;
s_func_type->n_arg_types = arg_types.n;
Expand Down

0 comments on commit df83b5a

Please sign in to comment.