From ea36ba781336b111eb6efd97a5b4332bb8e02ea7 Mon Sep 17 00:00:00 2001 From: Harsh Singh Jadon Date: Sat, 8 Apr 2023 14:41:35 +0530 Subject: [PATCH 1/2] Support for dict as args in function --- integration_tests/CMakeLists.txt | 3 +++ integration_tests/test_dict_08.py | 31 ++++++++++++++++++++++++++++++ integration_tests/test_dict_09.py | 29 ++++++++++++++++++++++++++++ integration_tests/test_dict_10.py | 20 +++++++++++++++++++ src/libasr/codegen/asr_to_llvm.cpp | 26 +++++++++++++++++++++++++ 5 files changed, 109 insertions(+) create mode 100644 integration_tests/test_dict_08.py create mode 100644 integration_tests/test_dict_09.py create mode 100644 integration_tests/test_dict_10.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index edead131190..7f4d525db50 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -290,6 +290,9 @@ RUN(NAME test_dict_04 LABELS cpython llvm) RUN(NAME test_dict_05 LABELS cpython llvm) RUN(NAME test_dict_06 LABELS cpython llvm c) RUN(NAME test_dict_07 LABELS cpython llvm) +RUN(NAME test_dict_08 LABELS cpython llvm c) +RUN(NAME test_dict_09 LABELS cpython llvm c) +RUN(NAME test_dict_10 LABELS cpython llvm) # TODO: Add support of dict with string in C backend RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_dict_08.py b/integration_tests/test_dict_08.py new file mode 100644 index 00000000000..18458494588 --- /dev/null +++ b/integration_tests/test_dict_08.py @@ -0,0 +1,31 @@ +# test case for passing dict as args and return value to a function + +from lpython import i32 + +def get_cubes_from_squares(squares: dict[i32, i32]) -> dict[i32, i32]: + i : i32 + cubes: dict[i32, i32] = {} + for i in range(1, 16): + cubes[i] = squares[i] * i + return cubes + +def assert_dict(squares: dict[i32, i32], cubes: dict[i32, i32]): + i : i32 + for i in range(1, 16): + assert squares[i] == (i * i) + for i in range(1, 16): + assert cubes[i] == (i * i * i) + assert len(squares) == 15 + assert len(cubes) == 15 + +def test_dict(): + squares : dict[i32, i32] + squares = {1:1, 2:4, 3:9, 4:16, 5:25, 6:36, 7:49, + 8:64, 9:81, 10:100, 11:121, 12:144, 13:169, + 14: 196, 15:225} + + cubes : dict[i32, i32] + cubes = get_cubes_from_squares(squares) + assert_dict(squares, cubes) + +test_dict() diff --git a/integration_tests/test_dict_09.py b/integration_tests/test_dict_09.py new file mode 100644 index 00000000000..412fce9aa85 --- /dev/null +++ b/integration_tests/test_dict_09.py @@ -0,0 +1,29 @@ +# test case for passing dict as args and return value to a function + +from lpython import f64, i32 + +def fill_rollnumber2cpi(size: i32) -> dict[i32, f64]: + i : i32 + rollnumber2cpi: dict[i32, f64] = {} + + rollnumber2cpi[0] = 1.1 + for i in range(1000, 1000 + size): + rollnumber2cpi[i] = float(i/100.0 + 5.0) + + return rollnumber2cpi + +def test_assertion(rollnumber2cpi: dict[i32, f64], size: i32): + i: i32 + for i in range(1000 + size - 1, 1001, -1): + assert abs(rollnumber2cpi[i] - i/100.0 - 5.0) <= 1e-12 + + assert abs(rollnumber2cpi[0] - 1.1) <= 1e-12 + assert len(rollnumber2cpi) == 201 + +def test_dict(): + size: i32 = 200 + rollnumber2cpi: dict[i32, f64] = fill_rollnumber2cpi(size) + + test_assertion(rollnumber2cpi, size) + +test_dict() diff --git a/integration_tests/test_dict_10.py b/integration_tests/test_dict_10.py new file mode 100644 index 00000000000..a4d026fc492 --- /dev/null +++ b/integration_tests/test_dict_10.py @@ -0,0 +1,20 @@ +# test case for passing dict with key-value as strings as argument to function + +def test_assertion(smalltocaps: dict[str, str]): + i : i32 + assert len(smalltocaps) == 26 + for i in range(97, 97 + 26): + assert smalltocaps[chr(i)] == chr(i - 32) + +def test_dict(): + smalltocaps: dict[str, str] + smalltocaps = {'a': 'A', 'b': 'B', 'c': 'C', 'd': 'D','e': 'E', + 'f': 'F', 'g': 'G', 'h': 'H', 'i': 'I','j': 'J', + 'k': 'K', 'l': 'L', 'm': 'M', 'n': 'N','o': 'O', + 'p': 'P', 'q': 'Q', 'r': 'R', 's': 'S','t': 'T', + 'u': 'U', 'v': 'V', 'w': 'W', 'x': 'X','y': 'Y', + 'z': 'Z'} + + test_assertion(smalltocaps) + +test_dict() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index c01450f288e..883753c8bc2 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -3641,6 +3641,32 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } break ; } + case (ASR::ttypeType::Dict): { + ASR::Dict_t* asr_dict = ASR::down_cast(asr_type); + std::string key_type_code = ASRUtils::get_type_code(asr_dict->m_key_type); + std::string value_type_code = ASRUtils::get_type_code(asr_dict->m_value_type); + + bool is_array_type = false, is_malloc_array_type = false; + bool is_list = false; + ASR::dimension_t* m_dims = nullptr; + llvm::Type* key_llvm_type = get_type_from_ttype_t(asr_dict->m_key_type, m_storage, + is_array_type, + is_malloc_array_type, + is_list, m_dims, n_dims, + a_kind, m_abi); + llvm::Type* value_llvm_type = get_type_from_ttype_t(asr_dict->m_value_type, m_storage, + is_array_type, + is_malloc_array_type, + is_list, m_dims, n_dims, + a_kind, m_abi); + int32_t key_type_size = get_type_size(asr_dict->m_key_type, key_llvm_type, a_kind); + int32_t value_type_size = get_type_size(asr_dict->m_value_type, value_llvm_type, a_kind); + set_dict_api(asr_dict); + type = llvm_utils->dict_api->get_dict_type(key_type_code, value_type_code, + key_type_size, value_type_size, + key_llvm_type, value_llvm_type)->getPointerTo(); + break; + } default : LCOMPILERS_ASSERT(false); } From 522581d4aacf5a7119d4f8b8691e7662542fd2eb Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Sun, 9 Apr 2023 20:15:07 +0530 Subject: [PATCH 2/2] Redesign pass_array_by_data.cpp for enhancing its robustness --- .gitignore | 1 + src/libasr/asr_utils.h | 24 +- src/libasr/pass/pass_array_by_data.cpp | 507 ++++++++++++------------- src/libasr/pass/pass_utils.h | 8 + 4 files changed, 263 insertions(+), 277 deletions(-) diff --git a/.gitignore b/.gitignore index 991469178c5..d9a4892a9b1 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ CPackConfig.cmake CPackSourceConfig.cmake _CPack_Packages /CMakeSettings.json +src/libasr/libasr.a.* ## libraries *.a diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index c2166f8b15f..b48b95468de 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -2801,19 +2801,14 @@ class SymbolDuplicator { new_args.push_back(al, new_arg); } - node_duplicator.success = true; - ASR::expr_t* new_return_var = node_duplicator.duplicate_expr(function->m_return_var); - if (ASR::is_a(*new_return_var)) { - ASR::Var_t* var = ASR::down_cast(new_return_var); - if (ASR::is_a(*(var->m_v))) { - ASR::Variable_t* variable = ASR::down_cast(var->m_v); - ASR::symbol_t* arg_symbol = function_symtab->get_symbol(variable->m_name); - new_return_var = ASRUtils::EXPR(make_Var_t(al, var->base.base.loc, arg_symbol)); + ASR::expr_t* new_return_var = function->m_return_var; + if( new_return_var ) { + node_duplicator.success = true; + new_return_var = node_duplicator.duplicate_expr(function->m_return_var); + if( !node_duplicator.success ) { + return nullptr; } } - if( !node_duplicator.success ) { - return nullptr; - } ASR::FunctionType_t* function_type = ASRUtils::get_FunctionType(function); @@ -3084,6 +3079,9 @@ static inline ASR::expr_t* compute_length_from_start_end(Allocator& al, ASR::exp } static inline bool is_pass_array_by_data_possible(ASR::Function_t* x, std::vector& v) { + // BindC interfaces already pass array by data pointer so we don't need to track + // them and use extra variables for their dimensional information. Only those functions + // need to be tracked which by default pass arrays by using descriptors. if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindC && ASRUtils::get_FunctionType(x)->m_deftype == ASR::deftypeType::Interface) { return false; @@ -3109,6 +3107,10 @@ static inline bool is_pass_array_by_data_possible(ASR::Function_t* x, std::vecto if( ASR::is_a(*argi->m_type) ) { return false; } + + // The following if check determines whether the i-th argument + // can be called by just passing the data pointer and + // dimensional information spearately via extra arguments. if( ASRUtils::is_dimension_empty(dims, n_dims) && (argi->m_intent == ASRUtils::intent_in || argi->m_intent == ASRUtils::intent_out || diff --git a/src/libasr/pass/pass_array_by_data.cpp b/src/libasr/pass/pass_array_by_data.cpp index 788133c240e..d3f1e82c485 100644 --- a/src/libasr/pass/pass_array_by_data.cpp +++ b/src/libasr/pass/pass_array_by_data.cpp @@ -8,6 +8,37 @@ #include #include +#include + +/* +This ASR to ASR pass can be called whenever you want to avoid +using descriptors for passing arrays to functions/subroutines +in the backend. By default, it is currently being used by all +the backends (via pass manager). Backends like WASM which +do not support array descriptors always need this pass +to be called. + +The possibility to avoid descriptors and pass array by data +to functions/subroutines is determined by ASRUtils::is_pass_array_by_data_possible +defined in asr_utils.h. + +Advantages and dis-advantages of this pass are as follows, + +Advantages: + +* Avoiding array descriptors and just using simple data points leads to +easier handling in the backend. + +* A lot of indirection to access dimensional information is also avoided +because it is already provided via extra variable arguments. + +Dis-advantages: + +* Requires access to all the code, as function interfaces have to be modified. Hence, not always possible. + +* The arrays become contiguous, which most of the time is fine, but sometimes you might lose performance. + +*/ namespace LCompilers { @@ -35,82 +66,16 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor> > proc2newproc; + std::set newprocs; PassArrayByDataProcedureVisitor(Allocator& al_) : PassVisitor(al_, nullptr), - node_duplicator(al_), current_proc_scope(nullptr), is_editing_procedure(false) + node_duplicator(al_) {} - void visit_Var(const ASR::Var_t& x) { - if( !is_editing_procedure ) { - return ; - } - ASR::Var_t& xx = const_cast(x); - ASR::symbol_t* x_sym = xx.m_v; - SymbolTable* x_sym_symtab = ASRUtils::symbol_parent_symtab(x_sym); - if( x_sym_symtab->get_counter() != current_proc_scope->get_counter() && - !ASRUtils::is_parent(x_sym_symtab, current_proc_scope) ) { - // xx.m_v points to the function/procedure present inside - // original function's symtab. Make it point to the symbol in - // new function's symtab. - std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym)); - xx.m_v = current_proc_scope->resolve_symbol(x_sym_name); - LCOMPILERS_ASSERT(xx.m_v != nullptr); - } - } - - void visit_BlockCall(const ASR::BlockCall_t& x) { - if( !is_editing_procedure ) { - return ; - } - ASR::BlockCall_t& xx = const_cast(x); - ASR::symbol_t* x_sym = xx.m_m; - SymbolTable* x_sym_symtab = ASRUtils::symbol_parent_symtab(x_sym); - if( x_sym_symtab->get_counter() != current_proc_scope->get_counter() && - !ASRUtils::is_parent(x_sym_symtab, current_proc_scope) ) { - // xx.m_v points to the function/procedure present inside - // original function's symtab. Make it point to the symbol in - // new function's symtab. - std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym)); - xx.m_m = current_proc_scope->resolve_symbol(x_sym_name); - LCOMPILERS_ASSERT(xx.m_m != nullptr); - } - } - - void visit_Call(ASR::symbol_t*& m_name) { - if( !is_editing_procedure ) { - return ; - } - ASR::symbol_t* x_sym = m_name; - SymbolTable* x_sym_symtab = ASRUtils::symbol_parent_symtab(x_sym); - if( x_sym_symtab->get_counter() != current_proc_scope->get_counter() && - !ASRUtils::is_parent(x_sym_symtab, current_proc_scope) ) { - // xx.m_v points to the function/procedure present inside - // original function's symtab. Make it point to the symbol in - // new function's symtab. - std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym)); - m_name = current_proc_scope->resolve_symbol(x_sym_name); - LCOMPILERS_ASSERT(m_name != nullptr); - } - } - - void visit_FunctionCall(const ASR::FunctionCall_t& x) { - ASR::FunctionCall_t& xx = const_cast(x); - visit_Call(xx.m_name); - PassUtils::PassVisitor::visit_FunctionCall(x); - } - - void visit_SubroutineCall(const ASR::SubroutineCall_t& x) { - ASR::SubroutineCall_t& xx = const_cast(x); - visit_Call(xx.m_name); - PassUtils::PassVisitor::visit_SubroutineCall(x); - } - ASR::symbol_t* insert_new_procedure(ASR::Function_t* x, std::vector& indices) { Vec new_body; new_body.reserve(al, x->n_body); @@ -195,22 +160,11 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitoradd_symbol(new_name, new_symbol); proc2newproc[(ASR::symbol_t*) x] = std::make_pair(new_symbol, indices); + newprocs.insert(new_symbol); return new_symbol; } - void visit_Block(const ASR::Block_t& x) { - SymbolTable* current_proc_scope_copy = current_proc_scope; - current_proc_scope = x.m_symtab; - for( auto itr: x.m_symtab->get_scope() ) { - visit_symbol(*itr.second); - } - for( size_t i = 0; i < x.n_body; i++ ) { - visit_stmt(*x.m_body[i]); - } - current_proc_scope = current_proc_scope_copy; - } - - void edit_new_procedure(ASR::Function_t* x, std::vector& indices) { + void edit_new_procedure_args(ASR::Function_t* x, std::vector& indices) { Vec new_args; new_args.reserve(al, x->n_args); for( size_t i = 0; i < x->n_args; i++ ) { @@ -223,7 +177,7 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor dim_variables; std::string arg_name = std::string(arg->m_name); PassUtils::create_vars(dim_variables, 2 * n_dims, arg->base.base.loc, al, - x->m_symtab, arg_name, ASR::intentType::In, arg->m_presence); + x->m_symtab, arg_name, ASR::intentType::In, arg->m_presence); Vec new_dims; new_dims.reserve(al, n_dims); for( int j = 0, k = 0; j < n_dims; j++ ) { @@ -244,25 +198,6 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitorm_args = new_args.p; x->n_args = new_args.size(); - - is_editing_procedure = true; - current_proc_scope = x->m_symtab; - for( auto& itr: x->m_symtab->get_scope() ) { - if( ASR::is_a(*itr.second) ) { - PassVisitor::visit_ttype(*ASR::down_cast(itr.second)->m_type); - } else if( ASR::is_a(*itr.second) || - ASR::is_a(*itr.second) ) { - SymbolTable* current_proc_scope_copy = current_proc_scope; - current_proc_scope = ASRUtils::symbol_symtab(itr.second); - visit_symbol(*itr.second); - current_proc_scope = current_proc_scope_copy; - } - } - for( size_t i = 0; i < x->n_body; i++ ) { - visit_stmt(*x->m_body[i]); - } - is_editing_procedure = false; - current_proc_scope = nullptr; } void visit_TranslationUnit(const ASR::TranslationUnit_t& x) { @@ -282,7 +217,8 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor - void visit_SymbolContainingFunctions(const T& x) { + bool visit_SymbolContainingFunctions(const T& x, + std::deque& pass_array_by_data_functions) { T& xx = const_cast(x); current_scope = xx.m_symtab; for( auto& item: xx.m_symtab->get_scope() ) { @@ -293,15 +229,25 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor(sym); - edit_new_procedure(new_subrout, arg_indices); + edit_new_procedure_args(new_subrout, arg_indices); + pass_array_by_data_functions.push_back(new_subrout); } } } } + return pass_array_by_data_functions.size() > 0; } + #define bfs_visit_SymbolContainingFunctions() std::deque pass_array_by_data_functions; \ + visit_SymbolContainingFunctions(x, pass_array_by_data_functions); \ + while( pass_array_by_data_functions.size() > 0 ) { \ + ASR::Function_t* function = pass_array_by_data_functions.front(); \ + pass_array_by_data_functions.pop_front(); \ + visit_SymbolContainingFunctions(*function, pass_array_by_data_functions); \ + } \ + void visit_Program(const ASR::Program_t& x) { - visit_SymbolContainingFunctions(x); + bfs_visit_SymbolContainingFunctions() } void visit_Module(const ASR::Module_t& x) { @@ -310,45 +256,150 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitor { + + public: + + PassArrayByDataProcedureVisitor& v; + + EditProcedureVisitor(PassArrayByDataProcedureVisitor& v_): + v(v_) {} + + void visit_Function(const ASR::Function_t &x) { + ASR::Function_t& xx = const_cast(x); + SymbolTable* current_scope_copy = current_scope; + current_scope = x.m_symtab; + for (auto &a : x.m_symtab->get_scope()) { + this->visit_symbol(*a.second); + } + + // See integration_tests/modules_26.f90 + // for the reason of commenting out + // the following line + // visit_ttype(*x.m_function_signature); + + for (size_t i=0; iget_counter() != current_scope->get_counter() && \ + !ASRUtils::is_parent(x_sym_symtab, current_scope) ) { \ + std::string x_sym_name = std::string(ASRUtils::symbol_name(x_sym)); \ + xx.m_##attr = current_scope->resolve_symbol(x_sym_name); \ + LCOMPILERS_ASSERT(xx.m_##attr != nullptr); \ + } \ + + void visit_Var(const ASR::Var_t& x) { + ASR::Var_t& xx = const_cast(x); + ASR::symbol_t* x_sym_ = xx.m_v; + if ( v.proc2newproc.find(x_sym_) != v.proc2newproc.end() ) { + xx.m_v = v.proc2newproc[x_sym_].first; + return ; + } + + edit_symbol(v) + } + + void visit_BlockCall(const ASR::BlockCall_t& x) { + ASR::BlockCall_t& xx = const_cast(x); + edit_symbol(m) + } + + void visit_FunctionCall(const ASR::FunctionCall_t& x) { + ASR::FunctionCall_t& xx = const_cast(x); + edit_symbol(name) + ASR::ASRPassBaseWalkVisitor::visit_FunctionCall(x); + } + + void visit_SubroutineCall(const ASR::SubroutineCall_t& x) { + ASR::SubroutineCall_t& xx = const_cast(x); + edit_symbol(name) + ASR::ASRPassBaseWalkVisitor::visit_SubroutineCall(x); + } + +}; + /* - The following visitor replaces subroutine calls with arrays as arguments - to subroutine calls having dimensional information passed as arguments. See example below, + The following visitor replaces procedure calls with arrays as arguments + to procedure calls having dimensional information passed as arguments. See example below, - call f(array1, array2) + call f1(array1, array2) + sum = f(array) + g(array) gets converted to, - call f_array1_array2(array1, m1, n1, array2, m2, n2) + call f1_array1_array2(array1, m1, n1, array2, m2, n2) + sum = f_array(array, m, n) + g_array(array, m, n) As can be seen dimensional information, m1, n1 is passed along with array1 and similarly m2, n2 is passed along with array2. */ -class ReplaceSubroutineCallsVisitor : public PassUtils::PassVisitor +class EditProcedureCallsVisitor : public ASR::ASRPassBaseWalkVisitor { private: + Allocator& al; PassArrayByDataProcedureVisitor& v; public: - ReplaceSubroutineCallsVisitor(Allocator& al_, PassArrayByDataProcedureVisitor& v_): PassVisitor(al_, nullptr), - v(v_) { - pass_result.reserve(al, 1); - } + EditProcedureCallsVisitor(Allocator& al_, + PassArrayByDataProcedureVisitor& v_): + al(al_), v(v_) {} - void visit_SubroutineCall(const ASR::SubroutineCall_t& x) { + template + void visit_Call(const T& x) { ASR::symbol_t* subrout_sym = x.m_name; bool is_external = ASR::is_a(*subrout_sym); subrout_sym = ASRUtils::symbol_get_past_external(subrout_sym); if( v.proc2newproc.find(subrout_sym) == v.proc2newproc.end() ) { + bool args_updated = false; + Vec new_args; + new_args.reserve(al, x.n_args); + for ( size_t i = 0; i < x.n_args; i++ ) { + ASR::call_arg_t arg = x.m_args[i]; + ASR::expr_t* expr = arg.m_value; + bool use_original_arg = true; + if (expr) { + if (ASR::is_a(*expr)) { + ASR::Var_t* var = ASR::down_cast(expr); + ASR::symbol_t* sym = var->m_v; + if ( v.proc2newproc.find(sym) != v.proc2newproc.end() ) { + ASR::symbol_t* new_var_sym = v.proc2newproc[sym].first; + ASR::expr_t* new_var = ASRUtils::EXPR(ASR::make_Var_t(al, var->base.base.loc, new_var_sym)); + ASR::call_arg_t new_arg; + new_arg.m_value = new_var; + new_arg.loc = arg.loc; + new_args.push_back(al, new_arg); + args_updated = true; + use_original_arg = false; + } + } + } + if( use_original_arg ) { + new_args.push_back(al, arg); + } + } + if (args_updated) { + T&xx = const_cast(x); + xx.m_args = new_args.p; + xx.n_args = new_args.size(); + } return ; } - ASR::symbol_t* new_subrout_sym = v.proc2newproc[subrout_sym].first; + ASR::symbol_t* new_func_sym = v.proc2newproc[subrout_sym].first; std::vector& indices = v.proc2newproc[subrout_sym].second; Vec new_args; @@ -356,7 +407,7 @@ class ReplaceSubroutineCallsVisitor : public PassUtils::PassVisitor(new_func_sym); + size_t min_args = 0, max_args = 0; + for( size_t i = 0; i < new_func_->n_args; i++ ) { + ASR::Var_t* arg = ASR::down_cast(new_func_->m_args[i]); + if( ASR::is_a(*arg->m_v) && + ASR::down_cast(arg->m_v)->m_presence + == ASR::presenceType::Optional ) { + max_args += 1; + } else { + min_args += 1; + max_args += 1; + } + } + if( !(min_args <= new_args.size() && + new_args.size() <= max_args) ) { + throw LCompilersException("Number of arguments in the new " + "function call doesn't satisfy " + "min_args <= new_args.size() <= max_args, " + + std::to_string(min_args) + " <= " + + std::to_string(new_args.size()) + " <= " + + std::to_string(max_args)); + } + } + ASR::symbol_t* new_func_sym_ = new_func_sym; if( is_external ) { - ASR::ExternalSymbol_t* subrout_ext_sym = ASR::down_cast(x.m_name); + ASR::ExternalSymbol_t* func_ext_sym = ASR::down_cast(x.m_name); // TODO: Use SymbolTable::get_unique_name to avoid potential // clashes with user defined functions - char* new_subrout_sym_name = ASRUtils::symbol_name(new_subrout_sym); - if( current_scope->get_symbol(new_subrout_sym_name) == nullptr ) { - new_subrout_sym_ = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, x.m_name->base.loc, current_scope, - new_subrout_sym_name, new_subrout_sym, subrout_ext_sym->m_module_name, - subrout_ext_sym->m_scope_names, subrout_ext_sym->n_scope_names, new_subrout_sym_name, - subrout_ext_sym->m_access)); - current_scope->add_symbol(new_subrout_sym_name, new_subrout_sym_); + char* new_func_sym_name = ASRUtils::symbol_name(new_func_sym); + if( current_scope->get_symbol(new_func_sym_name) == nullptr ) { + new_func_sym_ = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, x.m_name->base.loc, func_ext_sym->m_parent_symtab, + new_func_sym_name, new_func_sym, func_ext_sym->m_module_name, + func_ext_sym->m_scope_names, func_ext_sym->n_scope_names, new_func_sym_name, + func_ext_sym->m_access)); + func_ext_sym->m_parent_symtab->add_symbol(new_func_sym_name, new_func_sym_); } else { - new_subrout_sym_ = current_scope->get_symbol(new_subrout_sym_name); + new_func_sym_ = current_scope->resolve_symbol(new_func_sym_name); } - LCOMPILERS_ASSERT(ASR::is_a(*new_subrout_sym_)); - } - ASR::stmt_t* new_call = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, - x.base.base.loc, new_subrout_sym_, new_subrout_sym_, - new_args.p, new_args.size(), x.m_dt)); - pass_result.push_back(al, new_call); - } -}; - - -/* - -The following replacer replaces all the function call expressions with arrays -as arguments to function call expressions having dimensional information of -array arguments passed along. See example below, - - sum = f(array) + g(array) - -gets converted to, - - sum = f_array(array, m, n) + g_array(array, m, n) - -*/ -class ReplaceFunctionCalls: public ASR::BaseExprReplacer { - - private: - - Allocator& al; - PassArrayByDataProcedureVisitor& v; - - public: - - SymbolTable* current_scope; - - ReplaceFunctionCalls(Allocator& al_, PassArrayByDataProcedureVisitor& v_) : al(al_), v(v_) - {} - - void replace_FunctionCall(ASR::FunctionCall_t* x) { - ASR::symbol_t* subrout_sym = x->m_name; - bool is_external = ASR::is_a(*subrout_sym); - subrout_sym = ASRUtils::symbol_get_past_external(subrout_sym); - if( v.proc2newproc.find(subrout_sym) == v.proc2newproc.end() ) { - return ; - } - - ASR::symbol_t* new_func_sym = v.proc2newproc[subrout_sym].first; - std::vector& indices = v.proc2newproc[subrout_sym].second; - - Vec new_args; - new_args.reserve(al, x->n_args); - for( size_t i = 0; i < x->n_args; i++ ) { - new_args.push_back(al, x->m_args[i]); - if( std::find(indices.begin(), indices.end(), i) == indices.end() || - x->m_args[i].m_value == nullptr ) { - continue ; - } - - Vec dim_vars; - dim_vars.reserve(al, 2); - ASRUtils::get_dimensions(x->m_args[i].m_value, dim_vars, al); - for( size_t j = 0; j < dim_vars.size(); j++ ) { - ASR::call_arg_t dim_var; - dim_var.loc = dim_vars[j]->base.loc; - dim_var.m_value = dim_vars[j]; - new_args.push_back(al, dim_var); } + T& xx = const_cast(x); + xx.m_name = new_func_sym_; + xx.m_original_name = new_func_sym_; + xx.m_args = new_args.p; + xx.n_args = new_args.size(); } - { - ASR::Function_t* new_func_ = ASR::down_cast(new_func_sym); - size_t min_args = 0, max_args = 0; - for( size_t i = 0; i < new_func_->n_args; i++ ) { - ASR::Var_t* arg = ASR::down_cast(new_func_->m_args[i]); - if( ASR::is_a(*arg->m_v) && - ASR::down_cast(arg->m_v)->m_presence - == ASR::presenceType::Optional ) { - max_args += 1; - } else { - min_args += 1; - max_args += 1; - } - } - if( !(min_args <= new_args.size() && - new_args.size() <= max_args) ) { - throw LCompilersException("Number of arguments in the new " - "function call doesn't satisfy " - "min_args <= new_args.size() <= max_args, " + - std::to_string(min_args) + " <= " + - std::to_string(new_args.size()) + " <= " + - std::to_string(max_args)); - } - } - ASR::symbol_t* new_func_sym_ = new_func_sym; - if( is_external ) { - ASR::ExternalSymbol_t* func_ext_sym = ASR::down_cast(x->m_name); - // TODO: Use SymbolTable::get_unique_name to avoid potential - // clashes with user defined functions - char* new_func_sym_name = ASRUtils::symbol_name(new_func_sym); - if( current_scope->get_symbol(new_func_sym_name) == nullptr ) { - new_func_sym_ = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, x->m_name->base.loc, func_ext_sym->m_parent_symtab, - new_func_sym_name, new_func_sym, func_ext_sym->m_module_name, - func_ext_sym->m_scope_names, func_ext_sym->n_scope_names, new_func_sym_name, - func_ext_sym->m_access)); - current_scope->add_symbol(new_func_sym_name, new_func_sym_); - } else { - new_func_sym_ = current_scope->get_symbol(new_func_sym_name); - } + void visit_SubroutineCall(const ASR::SubroutineCall_t& x) { + visit_Call(x); + ASR::ASRPassBaseWalkVisitor::visit_SubroutineCall(x); } - ASR::expr_t* new_call = ASRUtils::EXPR(ASR::make_FunctionCall_t(al, - x->base.base.loc, new_func_sym_, new_func_sym_, - new_args.p, new_args.size(), x->m_type, nullptr, - x->m_dt)); - *current_expr = new_call; - } - -}; - -/* -The following visitor calls the above replacer i.e., ReplaceFunctionCalls -on expressions present in ASR so that FunctionCall get replaced everywhere -and we don't end up with false positives. -*/ -class ReplaceFunctionCallsVisitor : public ASR::CallReplacerOnExpressionsVisitor -{ - private: - - ReplaceFunctionCalls replacer; - - public: - - ReplaceFunctionCallsVisitor(Allocator& al_, - PassArrayByDataProcedureVisitor& v_) : replacer(al_, v_) {} - void call_replacer() { - replacer.current_expr = current_expr; - replacer.current_scope = current_scope; - replacer.replace_expr(*current_expr); + void visit_FunctionCall(const ASR::FunctionCall_t& x) { + visit_Call(x); + ASR::ASRPassBaseWalkVisitor::visit_FunctionCall(x); } - }; /* @@ -548,6 +500,11 @@ class RemoveArrayByDescriptorProceduresVisitor : public PassUtils::PassVisitor(x); current_scope = xx.m_symtab; @@ -566,17 +523,35 @@ class RemoveArrayByDescriptorProceduresVisitor : public PassUtils::PassVisitor(x); + current_scope = xx.m_symtab; + + std::vector to_be_erased; + + for( auto& item: current_scope->get_scope() ) { + if( v.proc2newproc.find(item.second) != v.proc2newproc.end() ) { + LCOMPILERS_ASSERT(item.first == ASRUtils::symbol_name(item.second)) + to_be_erased.push_back(item.first); + } + } + + for (auto &item: to_be_erased) { + current_scope->erase_symbol(item); + } + } + }; void pass_array_by_data(Allocator &al, ASR::TranslationUnit_t &unit, const LCompilers::PassOptions& /*pass_options*/) { PassArrayByDataProcedureVisitor v(al); v.visit_TranslationUnit(unit); - ReplaceSubroutineCallsVisitor u(al, v); + EditProcedureVisitor e(v); + e.visit_TranslationUnit(unit); + EditProcedureCallsVisitor u(al, v); u.visit_TranslationUnit(unit); - ReplaceFunctionCallsVisitor w(al, v); - w.visit_TranslationUnit(unit); - RemoveArrayByDescriptorProceduresVisitor x(al,v); + RemoveArrayByDescriptorProceduresVisitor x(al, v); x.visit_TranslationUnit(unit); PassUtils::UpdateDependenciesVisitor y(al); y.visit_TranslationUnit(unit); diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 4604b4c80ad..0062190c6e9 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -193,8 +193,16 @@ namespace LCompilers { ASR::Function_t &xx = const_cast(x); SymbolTable* current_scope_copy = this->current_scope; this->current_scope = xx.m_symtab; + self().visit_ttype(*x.m_function_signature); + for (size_t i=0; iget_scope()) { if (ASR::is_a(*item.second)) { ASR::Function_t *s = ASR::down_cast(item.second);