Skip to content

Commit

Permalink
[ASR Pass] Handle ArraySection and SIMDArray BinOp
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 17, 2023
1 parent 680f9b5 commit a6a2807
Showing 1 changed file with 45 additions and 6 deletions.
51 changes: 45 additions & 6 deletions src/libasr/pass/array_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,16 +716,25 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {
op_n_dims = x_dims.size();
}

ASR::ttype_t* x_m_type = ASRUtils::TYPE(ASR::make_Pointer_t(al, loc,
ASRUtils::type_get_past_allocatable(ASRUtils::duplicate_type(al,
ASRUtils::type_get_past_pointer(x->m_type), &empty_dims))));

ASR::ttype_t* x_m_type;
if (op_expr && ASRUtils::is_simd_array(op_expr)) {
x_m_type = ASRUtils::expr_type(op_expr);
} else {
x_m_type = ASRUtils::TYPE(ASR::make_Pointer_t(al, loc,
ASRUtils::type_get_past_allocatable(ASRUtils::duplicate_type(al,
ASRUtils::type_get_past_pointer(x->m_type), &empty_dims))));
}
ASR::expr_t* array_section_pointer = PassUtils::create_var(
result_counter, "_array_section_pointer_", loc,
x_m_type, al, current_scope);
result_counter += 1;
pass_result.push_back(al, ASRUtils::STMT(ASRUtils::make_Associate_t_util(
al, loc, array_section_pointer, *current_expr)));
if (op_expr && ASRUtils::is_simd_array(op_expr)) {
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(
al, loc, array_section_pointer, *current_expr, nullptr)));
} else {
pass_result.push_back(al, ASRUtils::STMT(ASRUtils::make_Associate_t_util(
al, loc, array_section_pointer, *current_expr)));
}
*current_expr = array_section_pointer;

// Might get used in other replace_* methods as well.
Expand All @@ -740,6 +749,33 @@ class ReplaceArrayOp: public ASR::BaseExprReplacer<ReplaceArrayOp> {

template <typename T>
void replace_ArrayOpCommon(T* x, std::string res_prefix) {
bool is_left_simd = ASRUtils::is_simd_array(x->m_left);
bool is_right_simd = ASRUtils::is_simd_array(x->m_right);
if ( is_left_simd && is_right_simd ) {
return;
} else if ( ( is_left_simd && !is_right_simd) ||
(!is_left_simd && is_right_simd) ) {
ASR::expr_t** current_expr_copy = current_expr;
ASR::expr_t* op_expr_copy = op_expr;
if (is_left_simd) {
// Replace ArraySection, case: a = a + b(:4)
if (ASR::is_a<ASR::ArraySection_t>(*x->m_right)) {
current_expr = &(x->m_right);
op_expr = x->m_left;
this->replace_expr(x->m_right);
}
} else {
// Replace ArraySection, case: a = b(:4) + a
if (ASR::is_a<ASR::ArraySection_t>(*x->m_left)) {
current_expr = &(x->m_left);
op_expr = x->m_right;
this->replace_expr(x->m_left);
}
}
current_expr = current_expr_copy;
op_expr = op_expr_copy;
return;
}
const Location& loc = x->base.base.loc;
bool current_status = use_custom_loop_params;
use_custom_loop_params = false;
Expand Down Expand Up @@ -1587,6 +1623,9 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit

void visit_Assignment(const ASR::Assignment_t &x) {
if (ASRUtils::is_simd_array(x.m_target)) {
if (!ASR::is_a<ASR::ArrayPhysicalCast_t>(*x.m_value)) {
this->visit_expr(*x.m_value);
}
return;
}
if( (ASR::is_a<ASR::Pointer_t>(*ASRUtils::expr_type(x.m_target)) &&
Expand Down

0 comments on commit a6a2807

Please sign in to comment.