Skip to content

Commit

Permalink
Merge pull request lcompilers#746 from czgdp1807/arr_fix
Browse files Browse the repository at this point in the history
Fixing ArrayItem/ArraySection nodes in ASR
  • Loading branch information
czgdp1807 committed Jul 7, 2022
2 parents 5555359 + c9bc740 commit 0cd2460
Show file tree
Hide file tree
Showing 28 changed files with 135 additions and 81 deletions.
4 changes: 2 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ expr
| DictConstant(expr* keys, expr* values, ttype type)
| DictLen(expr arg, ttype type, expr? value)
| Var(symbol v)
| ArrayItem(symbol v, array_index* args, ttype type, expr? value)
| ArraySection(symbol v, array_index* args, ttype type, expr? value)
| ArrayItem(expr v, array_index* args, ttype type, expr? value)
| ArraySection(expr v, array_index* args, ttype type, expr? value)
| ArraySize(expr v, expr? dim, ttype type, expr? value)
| ArrayBound(expr v, expr? dim, ttype type, arraybound bound,
expr? value)
Expand Down
15 changes: 14 additions & 1 deletion src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,20 @@ static inline ASR::ttype_t* duplicate_type(Allocator& al, const ASR::ttype_t* t,
tnew->m_kind, tnew->m_len, tnew->m_len_expr,
dimsp, dimsn));
}
default : throw LFortranException("Not implemented");
case ASR::ttypeType::Derived: {
ASR::Derived_t* tnew = ASR::down_cast<ASR::Derived_t>(t);
ASR::dimension_t* dimsp = dims ? dims->p : tnew->m_dims;
size_t dimsn = dims ? dims->n : tnew->n_dims;
return ASRUtils::TYPE(ASR::make_Derived_t(al, t->base.loc,
tnew->m_derived_type, dimsp, dimsn));
}
case ASR::ttypeType::Pointer: {
ASR::Pointer_t* ptr = ASR::down_cast<ASR::Pointer_t>(t);
ASR::ttype_t* dup_type = duplicate_type(al, ptr->m_type, dims);
return ASRUtils::TYPE(ASR::make_Pointer_t(al, ptr->base.base.loc,
dup_type));
}
default : throw LFortranException("Not implemented " + std::to_string(t->type));
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,7 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>

template <typename T>
void visit_ArrayItemSection(const T &x) {
require(symtab_in_scope(current_symtab, x.m_v),
"ArrayItem/ArraySection::m_v cannot point outside of its symbol table");
visit_expr(*x.m_v);
for (size_t i=0; i<x.n_args; i++) {
visit_array_index(x.m_args[i]);
}
Expand Down
18 changes: 4 additions & 14 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ R"(#include <stdio.h>
sv->m_intent == ASRUtils::intent_inout) &&
is_c && ASRUtils::is_array(sv->m_type) &&
ASRUtils::is_pointer(sv->m_type)) {
src = "*" + std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
src = "(*" + std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name) + ")";
} else {
src = std::string(ASR::down_cast<ASR::Variable_t>(s)->m_name);
}
Expand All @@ -600,20 +600,10 @@ R"(#include <stdio.h>
}

void visit_ArrayItem(const ASR::ArrayItem_t &x) {
const ASR::symbol_t *s = ASRUtils::symbol_get_past_external(x.m_v);
ASR::Variable_t* sv = ASR::down_cast<ASR::Variable_t>(s);
std::string prefix = "";
// if( ASR::is_a<ASR::Derived_t>(*sv->m_type) ) {
// prefix = "&";
// }
std::string out = std::string(sv->m_name);
if( (sv->m_intent == ASRUtils::intent_in ||
sv->m_intent == ASRUtils::intent_inout) &&
is_c && ASRUtils::is_pointer(sv->m_type) ) {
out = "(*" + out + ")";
}
this->visit_expr(*x.m_v);
std::string out = src;
ASR::dimension_t* m_dims;
ASRUtils::extract_dimensions_from_ttype(sv->m_type, m_dims);
ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(x.m_v), m_dims);
out += "[";
for (size_t i=0; i<x.n_args; i++) {
if (x.m_args[i].m_right) {
Expand Down
59 changes: 39 additions & 20 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1154,16 +1154,31 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr_wrapper(x.m_value, true);
return;
}
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(x.m_v);
if( ASR::is_a<ASR::Derived_t>(*v->m_type) ) {
ASR::Derived_t* der_type = ASR::down_cast<ASR::Derived_t>(v->m_type);
der_type_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(der_type->m_derived_type));
llvm::Value* array = nullptr;
if( ASR::is_a<ASR::Var_t>(*x.m_v) ) {
ASR::Variable_t *v = ASRUtils::EXPR2VAR(x.m_v);
if( ASR::is_a<ASR::Derived_t>(*v->m_type) ) {
ASR::Derived_t* der_type = ASR::down_cast<ASR::Derived_t>(v->m_type);
der_type_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(der_type->m_derived_type));
}
uint32_t v_h = get_hash((ASR::asr_t*)v);
LFORTRAN_ASSERT(llvm_symtab.find(v_h) != llvm_symtab.end());
array = llvm_symtab[v_h];
} else {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_v);
if( ASR::is_a<ASR::Derived_t>(*ASRUtils::expr_type(x.m_v)) ) {
ASR::Derived_t* der_type = ASR::down_cast<ASR::Derived_t>(ASRUtils::expr_type(x.m_v));
der_type_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(der_type->m_derived_type));
}
ptr_loads = ptr_loads_copy;
array = tmp;
}
uint32_t v_h = get_hash((ASR::asr_t*)v);
LFORTRAN_ASSERT(llvm_symtab.find(v_h) != llvm_symtab.end());
llvm::Value* array = llvm_symtab[v_h];
if (is_a<ASR::Character_t>(*x.m_type)
&& ASR::down_cast<ASR::Character_t>(x.m_type)->n_dims == 0) {
ASR::dimension_t* m_dims;
int n_dims = ASRUtils::extract_dimensions_from_ttype(
ASRUtils::expr_type(x.m_v), m_dims);
if (is_a<ASR::Character_t>(*x.m_type) && n_dims == 0) {
// String indexing:
if (x.n_args != 1) {
throw CodeGenError("Only string(a) supported for now.", x.base.base.loc);
Expand Down Expand Up @@ -1194,7 +1209,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ptr_loads = ptr_loads_copy;
indices.push_back(tmp);
}
if (v->m_type->type == ASR::ttypeType::Pointer) {
if (ASRUtils::expr_type(x.m_v)->type == ASR::ttypeType::Pointer) {
array = builder->CreateLoad(array);
}
tmp = arr_descr->get_single_element(array, indices, x.n_args);
Expand All @@ -1206,12 +1221,16 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
this->visit_expr_wrapper(x.m_value, true);
return;
}
ASR::Variable_t *v = ASR::down_cast<ASR::Variable_t>(x.m_v);
uint32_t v_h = get_hash((ASR::asr_t*)v);
LFORTRAN_ASSERT(llvm_symtab.find(v_h) != llvm_symtab.end());
llvm::Value* array = llvm_symtab[v_h];
LFORTRAN_ASSERT(ASR::is_a<ASR::Character_t>(*x.m_type) &&
ASR::down_cast<ASR::Character_t>(x.m_type)->n_dims == 0);
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_v);
ptr_loads = ptr_loads_copy;
llvm::Value* array = tmp;
ASR::dimension_t* m_dims;
int n_dims = ASRUtils::extract_dimensions_from_ttype(
ASRUtils::expr_type(x.m_v), m_dims);
LFORTRAN_ASSERT(ASR::is_a<ASR::Character_t>(*ASRUtils::expr_type(x.m_v)) &&
n_dims == 0);
// String indexing:
if (x.n_args == 1) {
throw CodeGenError("Only string(a:b) supported for now.", x.base.base.loc);
Expand Down Expand Up @@ -2783,8 +2802,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
target = tmp;
if (is_a<ASR::ArrayItem_t>(*x.m_target)) {
ASR::ArrayItem_t *asr_target0 = ASR::down_cast<ASR::ArrayItem_t>(x.m_target);
if (is_a<ASR::Variable_t>(*asr_target0->m_v)) {
ASR::Variable_t *asr_target = ASR::down_cast<ASR::Variable_t>(asr_target0->m_v);
if (is_a<ASR::Var_t>(*asr_target0->m_v)) {
ASR::Variable_t *asr_target = ASRUtils::EXPR2VAR(asr_target0->m_v);
if ( is_a<ASR::Character_t>(*asr_target->m_type) ) {
ASR::Character_t *t = ASR::down_cast<ASR::Character_t>(asr_target->m_type);
if (t->n_dims == 0) {
Expand All @@ -2795,8 +2814,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
} else if (is_a<ASR::ArraySection_t>(*x.m_target)) {
ASR::ArraySection_t *asr_target0 = ASR::down_cast<ASR::ArraySection_t>(x.m_target);
if (is_a<ASR::Variable_t>(*asr_target0->m_v)) {
ASR::Variable_t *asr_target = ASR::down_cast<ASR::Variable_t>(asr_target0->m_v);
if (is_a<ASR::Var_t>(*asr_target0->m_v)) {
ASR::Variable_t *asr_target = ASRUtils::EXPR2VAR(asr_target0->m_v);
if ( is_a<ASR::Character_t>(*asr_target->m_type) ) {
ASR::Character_t *t = ASR::down_cast<ASR::Character_t>(asr_target->m_type);
if (t->n_dims == 0) {
Expand Down
4 changes: 2 additions & 2 deletions src/libasr/pass/arr_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class ArrSliceVisitor : public PassUtils::PassVisitor<ArrSliceVisitor>

void visit_ArraySection(const ASR::ArraySection_t& x) {
if( create_slice_var ) {
ASR::expr_t* x_arr_var = LFortran::ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, x.m_v));
ASR::expr_t* x_arr_var = x.m_v;
Str new_name_str;
new_name_str.from_str(al, "~" + std::to_string(slice_counter) + "_slice");
slice_counter += 1;
Expand Down Expand Up @@ -187,7 +187,7 @@ class ArrSliceVisitor : public PassUtils::PassVisitor<ArrSliceVisitor>
doloop_body.reserve(al, 1);
if( doloop == nullptr ) {
ASR::expr_t* target_ref = PassUtils::create_array_ref(slice_sym, idx_vars_target, al, x.base.base.loc, x.m_type);
ASR::expr_t* value_ref = PassUtils::create_array_ref(x.m_v, idx_vars_value, al, x.base.base.loc, x.m_type);
ASR::expr_t* value_ref = PassUtils::create_array_ref(x.m_v, idx_vars_value, al);
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, target_ref, value_ref, nullptr));
doloop_body.push_back(al, assign_stmt);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class ArrayOpVisitor : public PassUtils::PassVisitor<ArrayOpVisitor>
this->visit_expr(*(x.m_value));
} else if( ASR::is_a<ASR::ArraySection_t>(*x.m_target) ) {
ASR::ArraySection_t* array_ref = ASR::down_cast<ASR::ArraySection_t>(x.m_target);
result_var = LFortran::ASRUtils::EXPR(ASR::make_Var_t(al, x.m_target->base.loc, array_ref->m_v));
result_var = array_ref->m_v;
result_lbound.reserve(al, array_ref->n_args);
result_ubound.reserve(al, array_ref->n_args);
result_inc.reserve(al, array_ref->n_args);
Expand Down
23 changes: 16 additions & 7 deletions src/libasr/pass/implied_do_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor<ImpliedDoLoopVisitor>
head.loc = head.m_v->base.loc;
Vec<ASR::stmt_t*> doloop_body;
doloop_body.reserve(al, 1);
ASR::symbol_t* arr = arr_var->m_v;
ASR::ttype_t *_type = LFortran::ASRUtils::expr_type(idoloop->m_start);
ASR::expr_t* const_1 = LFortran::ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, arr_var->base.base.loc, 1, _type));
ASR::expr_t *const_n, *offset, *num_grps, *grp_start;
Expand All @@ -127,9 +126,14 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor<ImpliedDoLoopVisitor>
ai.m_step = nullptr;
args.reserve(al, 1);
args.push_back(al, ai);
ASR::expr_t* array_ref = LFortran::ASRUtils::EXPR(ASR::make_ArrayItem_t(al, arr_var->base.base.loc, arr,
args.p, args.size(),
LFortran::ASRUtils::expr_type(LFortran::ASRUtils::EXPR((ASR::asr_t*)arr_var)), nullptr));
ASR::ttype_t* array_ref_type = ASRUtils::expr_type(ASRUtils::EXPR((ASR::asr_t*)arr_var));
Vec<ASR::dimension_t> empty_dims;
empty_dims.reserve(al, 1);
array_ref_type = ASRUtils::duplicate_type(al, array_ref_type, &empty_dims);
ASR::expr_t* array_ref = LFortran::ASRUtils::EXPR(ASR::make_ArrayItem_t(al, arr_var->base.base.loc,
ASRUtils::EXPR((ASR::asr_t*)arr_var),
args.p, args.size(),
array_ref_type, nullptr));
if( idoloop->m_values[i]->type == ASR::exprType::ImpliedDoLoop ) {
throw LFortranException("Pass for nested ImpliedDoLoop nodes isn't implemented yet."); // idoloop->m_values[i]->base.loc
}
Expand Down Expand Up @@ -191,9 +195,14 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor<ImpliedDoLoopVisitor>
ai.m_step = nullptr;
args.reserve(al, 1);
args.push_back(al, ai);
ASR::expr_t* array_ref = LFortran::ASRUtils::EXPR(ASR::make_ArrayItem_t(al, arr_var->base.base.loc, arr_var->m_v,
args.p, args.size(),
LFortran::ASRUtils::expr_type(LFortran::ASRUtils::EXPR((ASR::asr_t*)arr_var)), nullptr));
ASR::ttype_t* array_ref_type = ASRUtils::expr_type(ASRUtils::EXPR((ASR::asr_t*)arr_var));
Vec<ASR::dimension_t> empty_dims;
empty_dims.reserve(al, 1);
array_ref_type = ASRUtils::duplicate_type(al, array_ref_type, &empty_dims);
ASR::expr_t* array_ref = LFortran::ASRUtils::EXPR(ASR::make_ArrayItem_t(al, arr_var->base.base.loc,
ASRUtils::EXPR((ASR::asr_t*)arr_var),
args.p, args.size(),
array_ref_type, nullptr));
ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, array_ref, arr_init->m_args[k], nullptr));
pass_result.push_back(al, assign_stmt);
ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, idx_var, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(idx_var), nullptr));
Expand Down
8 changes: 4 additions & 4 deletions src/libasr/pass/loop_vectorise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class LoopVectoriseVisitor : public PassUtils::SkipOptimizationSubroutineVisitor
ASR::is_a<ASR::WhileLoop_t>(x));
}

bool is_vector_copy(ASR::stmt_t* x, Vec<ASR::symbol_t*>& arrays) {
bool is_vector_copy(ASR::stmt_t* x, Vec<ASR::expr_t*>& arrays) {
if( !ASR::is_a<ASR::Assignment_t>(*x) ) {
return false;
}
Expand Down Expand Up @@ -100,11 +100,11 @@ class LoopVectoriseVisitor : public PassUtils::SkipOptimizationSubroutineVisitor
ASR::expr_t*& vector_length,
Vec<ASR::stmt_t*>& vectorised_loop_body) {
LFORTRAN_ASSERT(vectorised_loop_body.reserve_called);
Vec<ASR::symbol_t*> arrays;
Vec<ASR::expr_t*> arrays;
arrays.reserve(al, 2);
if( is_vector_copy(loop_stmt, arrays) ) {
ASR::symbol_t *target_sym = arrays[0], *value_sym = arrays[1];
ASR::ttype_t* target_type = ASRUtils::symbol_type(target_sym);
ASR::expr_t *target_sym = arrays[0], *value_sym = arrays[1];
ASR::ttype_t* target_type = ASRUtils::expr_type(target_sym);
int64_t vector_length_int = get_vector_length(target_type);
vector_length = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loop_stmt->base.loc,
vector_length_int, ASRUtils::expr_type(index)));
Expand Down
Loading

0 comments on commit 0cd2460

Please sign in to comment.