From 2be0570ffedfd635651bb2ff155b1ebfa44c8fa0 Mon Sep 17 00:00:00 2001 From: Smit-create Date: Fri, 3 Mar 2023 20:31:38 +0530 Subject: [PATCH] C: Generalize handing of extra codegen --- src/libasr/codegen/asr_to_c.cpp | 19 ++-- src/libasr/codegen/asr_to_c_cpp.h | 148 +++++++++++++++--------------- 2 files changed, 81 insertions(+), 86 deletions(-) diff --git a/src/libasr/codegen/asr_to_c.cpp b/src/libasr/codegen/asr_to_c.cpp index a19c4b1090..95a42f4979 100644 --- a/src/libasr/codegen/asr_to_c.cpp +++ b/src/libasr/codegen/asr_to_c.cpp @@ -1016,7 +1016,7 @@ R"( void visit_Assert(const ASR::Assert_t &x) { std::string indent(indentation_level*indentation_spaces, ' '); std::string out = indent; - tmp_src.clear(); + bracket_open++; if (x.m_msg) { out += "ASSERT_MSG("; visit_expr(*x.m_test); @@ -1028,12 +1028,8 @@ R"( visit_expr(*x.m_test); out += src + ");\n"; } - src = ""; - if (!tmp_src.empty()) { - for (auto &s: tmp_src) src += s; - } - src += out; - tmp_src.clear(); + bracket_open--; + src = check_tmp_buffer() + out; } void visit_CPtrToPointer(const ASR::CPtrToPointer_t& x) { @@ -1069,6 +1065,7 @@ R"( void visit_Print(const ASR::Print_t &x) { std::string indent(indentation_level*indentation_spaces, ' '); std::string tmp_gen = indent + "printf(\"", out = ""; + bracket_open++; std::vector v; std::string separator; if (x.m_separator) { @@ -1093,11 +1090,6 @@ R"( } tmp_gen += ");\n"; out += tmp_gen; - if (ASR::is_a(*x.m_values[i]) || - ASR::is_a(*x.m_values[i])) { - out += src; - src = const_var_names[get_hash((ASR::asr_t*)x.m_values[i])]; - } tmp_gen = indent + "printf(\""; v.clear(); std::string p_func = c_ds_api->get_print_func(value_type); @@ -1129,8 +1121,9 @@ R"( } } tmp_gen += ");\n"; + bracket_open--; out += tmp_gen; - src = out; + src = this->check_tmp_buffer() + out; } void visit_ArraySize(const ASR::ArraySize_t& x) { diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 87ea4c36fd..e1fa5b85fd 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -103,7 +103,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor bool gen_stdcomplex; bool is_c; std::set headers; - std::vector tmp_src; + std::vector tmp_buffer_src; SymbolTable* global_scope; int64_t lower_bound; @@ -116,6 +116,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor std::string const_name; size_t const_vars_count; size_t loop_end_count; + int bracket_open; SymbolTable* current_scope; bool is_string_concat_present; @@ -128,7 +129,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor is_c{is_c}, global_scope{nullptr}, lower_bound{default_lower_bound}, template_number{0}, c_ds_api{std::make_unique(is_c, platform)}, const_name{"constname"}, - const_vars_count{0}, loop_end_count{0}, + const_vars_count{0}, loop_end_count{0}, bracket_open{0}, is_string_concat_present{false} { } @@ -198,6 +199,15 @@ R"(#include src = unit_src; } + std::string check_tmp_buffer() { + std::string ret = ""; + if (bracket_open == 0 && !tmp_buffer_src.empty()) { + for (auto &s: tmp_buffer_src) ret += s; + tmp_buffer_src.clear(); + } + return ret; + } + void visit_Module(const ASR::Module_t &x) { if (startswith(x.m_name, "lfortran_intrinsic_")) { intrinsic_module = true; @@ -319,6 +329,7 @@ R"(#include self().visit_stmt(*block->m_body[i]); body += src; } + decl += check_tmp_buffer(); src = open_paranthesis + decl + body + close_paranthesis; indentation_level -= 1; } @@ -406,6 +417,7 @@ R"(#include sym_name = "_xx_lcompilers_changed_exit_xx"; } std::string func = static_attr + inl + sub + sym_name + "("; + bracket_open++; for (size_t i=0; im_intent)); @@ -422,6 +434,7 @@ R"(#include if (i < x.n_args-1) func += ", "; } func += ")"; + bracket_open--; if( is_c || template_for_Kokkos.empty() ) { return func; } @@ -513,6 +526,7 @@ R"(#include self().visit_stmt(*x.m_body[i]); current_body += src; } + decl += check_tmp_buffer(); current_function = nullptr; bool visited_return = false; @@ -573,11 +587,13 @@ R"(#include } } else { std::string args; + bracket_open++; for (size_t i=0; i const_vars_count += 1; const_name = current_scope->get_unique_name(const_name); std::string indent(indentation_level*indentation_spaces, ' '); - current_body += indent + c_ds_api->get_list_type(list_type) + " " + - const_name + " = " + src + ";\n"; + tmp_buffer_src.push_back(check_tmp_buffer() + indent + c_ds_api->get_list_type(list_type) + " " + + const_name + " = " + src + ";\n"); src = const_name; + return; } + src = check_tmp_buffer() + src; } void visit_SizeOfType(const ASR::SizeOfType_t& x) { @@ -699,7 +717,6 @@ R"(#include ASR::TupleConstant_t *tup_const = ASR::down_cast(x.m_value); self().visit_TupleConstant(*tup_const); val_name = const_var_names[get_hash((ASR::asr_t*)tup_const)]; - src_tmp += src; } else if (ASR::is_a(*x.m_value)) { self().visit_FunctionCall(*ASR::down_cast(x.m_value)); ASR::Tuple_t* t = ASR::down_cast(tup_c->m_type); @@ -719,7 +736,7 @@ R"(#include src_tmp += indent + c_ds_api->get_deepcopy(t, val_name + ".element_" + std::to_string(i), src) + "\n"; } - src = src_tmp; + src = check_tmp_buffer() + src_tmp; return; } else { LCOMPILERS_ASSERT(false) @@ -751,42 +768,23 @@ R"(#include } else { src.clear(); } + src = check_tmp_buffer(); if( is_target_list && is_value_list ) { ASR::List_t* list_target = ASR::down_cast(ASRUtils::expr_type(x.m_target)); std::string list_dc_func = c_ds_api->get_list_deepcopy_func(list_target); - if( ASR::is_a(*x.m_value) ) { - src += value; - ASR::ListConstant_t *l_const = ASR::down_cast(x.m_value); - std::string var_name = const_var_names[get_hash((ASR::asr_t*)l_const)]; - src += indent + list_dc_func + "(&" + var_name + ", &" + target + ");\n\n"; - } else if (ASR::is_a(*x.m_value)) { + if (ASR::is_a(*x.m_value)) { src += indent + list_dc_func + "(" + value + ", &" + target + ");\n\n"; - } else if (ASR::is_a(*x.m_value)) { - src += value; - ASR::ListSection_t *l_sec = ASR::down_cast(x.m_value); - std::string var_name = const_var_names[get_hash((ASR::asr_t*)l_sec)]; - src += indent + list_dc_func + "(" + var_name + ", &" + target + ");\n\n"; } else { src += indent + list_dc_func + "(&" + value + ", &" + target + ");\n\n"; } } else if ( is_target_tup && is_value_tup ) { ASR::Tuple_t* tup_target = ASR::down_cast(ASRUtils::expr_type(x.m_target)); std::string dc_func = c_ds_api->get_tuple_deepcopy_func(tup_target); - if( ASR::is_a(*x.m_value) ) { - src += value; - src += indent + dc_func + "(" + const_name + ", &" + target + ");\n"; - } else { - src += indent + dc_func + "(" + value + ", &" + target + ");\n"; - } + src += indent + dc_func + "(" + value + ", &" + target + ");\n"; } else if ( is_target_dict && is_value_dict ) { ASR::Dict_t* d_target = ASR::down_cast(ASRUtils::expr_type(x.m_target)); std::string dc_func = c_ds_api->get_dict_deepcopy_func(d_target); - if( ASR::is_a(*x.m_value) ) { - src += value; - src += indent + dc_func + "(&" + const_name + ", &" + target + ");\n"; - } else { - src += indent + dc_func + "(&" + value + ", &" + target + ");\n"; - } + src += indent + dc_func + "(&" + value + ", &" + target + ");\n"; } else { if( is_c ) { std::string alloc = ""; @@ -906,16 +904,12 @@ R"(#include if( ASR::is_a(*t->m_type) ) { src_tmp += indent + var_name + ".data[" + std::to_string(i) +"] = NULL;\n"; } - if (ASR::is_a(*x.m_args[i]) || - ASR::is_a(*x.m_args[i])) { - src_tmp += src; - src = const_var_names[get_hash((ASR::asr_t*)x.m_args[i])]; - } src_tmp += indent + c_ds_api->get_deepcopy(t->m_type, src, var_name + ".data[" + std::to_string(i) +"]") + "\n"; } src_tmp += indent + var_name + ".current_end_point = " + std::to_string(x.n_args) + ";\n"; - src = src_tmp; + src = var_name; + tmp_buffer_src.push_back(src_tmp); } void visit_TupleConstant(const ASR::TupleConstant_t& x) { @@ -939,7 +933,8 @@ R"(#include src_tmp += indent + c_ds_api->get_deepcopy(t->m_type[i], src, var_name + ele) + "\n"; } src_tmp += indent + var_name + ".length" + " = " + std::to_string(x.n_elements) + ";\n"; - src = src_tmp; + src = var_name; + tmp_buffer_src.push_back(src_tmp); } void visit_DictConstant(const ASR::DictConstant_t& x) { @@ -967,21 +962,25 @@ R"(#include src_tmp += indent + dict_ins_func + "(&" + var_name + ", " +\ k + ", " + v + ");\n"; } - src = src_tmp; + src = var_name; + tmp_buffer_src.push_back(src_tmp); } void visit_TupleCompare(const ASR::TupleCompare_t& x) { ASR::ttype_t* type = ASRUtils::expr_type(x.m_left); std::string tup_cmp_func = c_ds_api->get_compare_func(type); + bracket_open++; self().visit_expr(*x.m_left); std::string left = std::move(src); self().visit_expr(*x.m_right); std::string right = std::move(src); + bracket_open--; std::string indent(indentation_level * indentation_spaces, ' '); src = tup_cmp_func + "(" + left + ", " + right + ")"; if (x.m_op == ASR::cmpopType::NotEq) { src = "!" + src; } + src = check_tmp_buffer() + src; } void visit_DictInsert(const ASR::DictInsert_t& x) { @@ -1016,33 +1015,38 @@ R"(#include ASR::ttype_t* t_ttype = ASRUtils::expr_type(x.m_a); ASR::List_t* t = ASR::down_cast(t_ttype); std::string list_append_func = c_ds_api->get_list_append_func(t); + bracket_open++; self().visit_expr(*x.m_a); std::string list_var = std::move(src); self().visit_expr(*x.m_ele); std::string element = std::move(src); + bracket_open--; std::string indent(indentation_level * indentation_spaces, ' '); - src = indent + list_append_func + "(&" + list_var + ", " + element + ");\n"; + src = check_tmp_buffer(); + src += indent + list_append_func + "(&" + list_var + ", " + element + ");\n"; } void visit_ListConcat(const ASR::ListConcat_t& x) { ASR::List_t* t = ASR::down_cast(x.m_type); std::string list_concat_func = c_ds_api->get_list_concat_func(t); + bracket_open++; self().visit_expr(*x.m_left); std::string left = std::move(src); if (!ASR::is_a(*x.m_left)) { left = "&" + left; } self().visit_expr(*x.m_right); + bracket_open--; std::string rig = std::move(src); if (!ASR::is_a(*x.m_right)) { rig = "&" + rig; } - std::string indent(indentation_level * indentation_spaces, ' '); - src = list_concat_func + "(" + left + ", " + rig + ")"; + src = check_tmp_buffer() + list_concat_func + "(" + left + ", " + rig + ")"; } void visit_ListSection(const ASR::ListSection_t& x) { std::string left, right, step, l_present, r_present; + bracket_open++; if (x.m_section.m_left) { self().visit_expr(*x.m_section.m_left); left = src; @@ -1066,7 +1070,7 @@ R"(#include step = "1"; } self().visit_expr(*x.m_a); - + bracket_open--; ASR::ttype_t* t_ttype = ASRUtils::expr_type(x.m_a); ASR::List_t* t = ASR::down_cast(t_ttype); std::string list_var = std::move(src); @@ -1081,80 +1085,69 @@ R"(#include tmp_src_gen += list_section_func + "(&" + list_var + ", " + left + ", " + right + ", " + step + ", " + l_present + ", " + r_present + ");\n"; const_var_names[get_hash((ASR::asr_t*)&x)] = var_name; - src = tmp_src_gen; + tmp_buffer_src.push_back(tmp_src_gen); + src = "* " + var_name; } void visit_ListClear(const ASR::ListClear_t& x) { ASR::ttype_t* t_ttype = ASRUtils::expr_type(x.m_a); ASR::List_t* t = ASR::down_cast(t_ttype); std::string list_clear_func = c_ds_api->get_list_clear_func(t); + bracket_open++; self().visit_expr(*x.m_a); + bracket_open--; std::string list_var = std::move(src); std::string indent(indentation_level * indentation_spaces, ' '); - src = indent + list_clear_func + "(&" + list_var + ");\n"; + src = check_tmp_buffer() + indent + list_clear_func + "(&" + list_var + ");\n"; } void visit_ListCompare(const ASR::ListCompare_t& x) { ASR::ttype_t* type = ASRUtils::expr_type(x.m_left); std::string list_cmp_func = c_ds_api->get_compare_func(type); + bracket_open++; self().visit_expr(*x.m_left); std::string left = std::move(src); self().visit_expr(*x.m_right); - std::string right = std::move(src), tmp_gen=""; + bracket_open--; + std::string right = std::move(src), tmp_gen= ""; std::string indent(indentation_level * indentation_spaces, ' '); - if (ASR::is_a(*x.m_left) ) { - tmp_gen += left; - ASR::ListConstant_t *l_const = ASR::down_cast(x.m_left); - left = const_var_names[get_hash((ASR::asr_t*)l_const)]; - } else if (ASR::is_a(*x.m_left)) { - tmp_gen += left; - ASR::ListSection_t *l_sec = ASR::down_cast(x.m_left); - left = "*" + const_var_names[get_hash((ASR::asr_t*)l_sec)]; - } - - if (ASR::is_a(*x.m_right) ) { - tmp_gen += right; - ASR::ListConstant_t *l_const = ASR::down_cast(x.m_right); - right = const_var_names[get_hash((ASR::asr_t*)l_const)]; - } else if (ASR::is_a(*x.m_right)) { - tmp_gen += right; - ASR::ListSection_t *l_sec = ASR::down_cast(x.m_right); - right = "*" + const_var_names[get_hash((ASR::asr_t*)l_sec)]; - } std::string val = list_cmp_func + "(" + left + ", " + right + ")"; if (x.m_op == ASR::cmpopType::NotEq) { val = "!" + val; } - src = val; - if (tmp_gen.size() > 0) { - tmp_src.push_back(tmp_gen); - } + src = check_tmp_buffer() + val; } void visit_ListInsert(const ASR::ListInsert_t& x) { ASR::ttype_t* t_ttype = ASRUtils::expr_type(x.m_a); ASR::List_t* t = ASR::down_cast(t_ttype); std::string list_insert_func = c_ds_api->get_list_insert_func(t); + bracket_open++; self().visit_expr(*x.m_a); std::string list_var = std::move(src); self().visit_expr(*x.m_ele); std::string element = std::move(src); self().visit_expr(*x.m_pos); + bracket_open--; std::string pos = std::move(src); std::string indent(indentation_level * indentation_spaces, ' '); - src = indent + list_insert_func + "(&" + list_var + ", " + pos + ", " + element + ");\n"; + src = check_tmp_buffer(); + src += indent + list_insert_func + "(&" + list_var + ", " + pos + ", " + element + ");\n"; } void visit_ListRemove(const ASR::ListRemove_t& x) { ASR::ttype_t* t_ttype = ASRUtils::expr_type(x.m_a); ASR::List_t* t = ASR::down_cast(t_ttype); std::string list_remove_func = c_ds_api->get_list_remove_func(t); + bracket_open++; self().visit_expr(*x.m_a); std::string list_var = std::move(src); self().visit_expr(*x.m_ele); + bracket_open--; std::string element = std::move(src); std::string indent(indentation_level * indentation_spaces, ' '); - src = indent + list_remove_func + "(&" + list_var + ", " + element + ");\n"; + src = check_tmp_buffer(); + src += indent + list_remove_func + "(&" + list_var + ", " + element + ");\n"; } void visit_ListLen(const ASR::ListLen_t& x) { @@ -1756,6 +1749,7 @@ R"(#include for (size_t i = 0; i < x.n_body; i++) { if (i > 0) out += indent + "else if ("; + bracket_open++; ASR::case_stmt_t* stmt = x.m_body[i]; if (stmt->type == ASR::case_stmtType::CaseStmt) { ASR::CaseStmt_t* case_stmt = ASR::down_cast(stmt); @@ -1766,6 +1760,7 @@ R"(#include out += var + " == " + src; } out += ") {\n"; + bracket_open--; indentation_level += 1; for (size_t j = 0; j < case_stmt->n_body; j++) { this->visit_stmt(*case_stmt->m_body[j]); @@ -1798,6 +1793,7 @@ R"(#include out += left + " <= " + var + " <= " + right; } out += ") {\n"; + bracket_open--; indentation_level += 1; for (size_t j = 0; j < case_stmt_range->n_body; j++) { this->visit_stmt(*case_stmt_range->m_body[j]); @@ -1817,18 +1813,21 @@ R"(#include out += indent + "}\n"; indentation_level -= 1; } - src = out; + src = check_tmp_buffer() + out; } void visit_WhileLoop(const ASR::WhileLoop_t &x) { std::string indent(indentation_level*indentation_spaces, ' '); + bracket_open++; std::string out = indent + "while ("; self().visit_expr(*x.m_test); out += src + ") {\n"; + bracket_open--; + out = check_tmp_buffer() + out; indentation_level += 1; for (size_t i=0; i current_body = ""; std::string indent(indentation_level*indentation_spaces, ' '); std::string out = indent + "if ("; + bracket_open++; self().visit_expr(*x.m_test); out += src + ") {\n"; + bracket_open--; + out = check_tmp_buffer() + out; indentation_level += 1; for (size_t i=0; i out += " else {\n"; for (size_t i=0; i