Skip to content

Commit

Permalink
Fixing symbolic pass according to error caught in LFortran (lcompiler…
Browse files Browse the repository at this point in the history
  • Loading branch information
anutosh491 committed Mar 2, 2024
1 parent 68774de commit 71c95ce
Showing 1 changed file with 51 additions and 21 deletions.
72 changes: 51 additions & 21 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
{handle_argument(al, loc, value_01), handle_argument(al, loc, value_02)},
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)));
}

static inline bool is_logical_intrinsic_symbolic(ASR::expr_t* expr) {
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(expr);
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
switch (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ:
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ:
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ:
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ:
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ:
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ:
return true;
default:
return false;
}
}
return true;
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -514,9 +533,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) {
process_intrinsic_function(x.base.base.loc, intrinsic_func, x.m_target);
} else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
pass_result.push_back(al, stmt);
if (is_logical_intrinsic_symbolic(x.m_value)) {
ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
pass_result.push_back(al, stmt);
}
}
} else if (ASR::is_a<ASR::Cast_t>(*x.m_value)) {
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(x.m_value);
Expand Down Expand Up @@ -676,18 +697,22 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test);
xx.m_test = function_call;
if (is_logical_intrinsic_symbolic(xx.m_test)) {
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test);
xx.m_test = function_call;
}
}
} else if (ASR::is_a<ASR::LogicalNot_t>(*xx.m_test)) {
ASR::LogicalNot_t* logical_not = ASR::down_cast<ASR::LogicalNot_t>(xx.m_test);
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*logical_not->m_arg)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(logical_not->m_arg);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, logical_not->m_arg);
ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call,
logical_not->m_type, logical_not->m_value));
xx.m_test = new_logical_not;
if (is_logical_intrinsic_symbolic(logical_not->m_arg)) {
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, logical_not->m_arg);
ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call,
logical_not->m_type, logical_not->m_value));
xx.m_test = new_logical_not;
}
}
}
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*xx.m_test)) {
Expand Down Expand Up @@ -784,8 +809,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
// Now create the FunctionCall node for basic_str
print_tmp.push_back(basic_str(x.base.base.loc, target));
} else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type(val))) {
ASR::expr_t* function_call = process_attributes(x.base.base.loc, val);
print_tmp.push_back(function_call);
if (is_logical_intrinsic_symbolic(val)) {
ASR::expr_t* function_call = process_attributes(x.base.base.loc, val);
print_tmp.push_back(function_call);
}
}
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
Expand Down Expand Up @@ -926,14 +953,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* right_tmp = nullptr;
if (ASR::is_a<ASR::LogicalCompare_t>(*x.m_test)) {
ASR::LogicalCompare_t *l = ASR::down_cast<ASR::LogicalCompare_t>(x.m_test);
if (is_logical_intrinsic_symbolic(l->m_left) && is_logical_intrinsic_symbolic(l->m_right)) {
left_tmp = process_attributes(x.base.base.loc, l->m_left);
right_tmp = process_attributes(x.base.base.loc, l->m_right);
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp,
l->m_op, right_tmp, l->m_type, l->m_value));

left_tmp = process_attributes(x.base.base.loc, l->m_left);
right_tmp = process_attributes(x.base.base.loc, l->m_right);
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp,
l->m_op, right_tmp, l->m_type, l->m_value));

ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
}
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test)) {
ASR::SymbolicCompare_t* s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test);
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
Expand All @@ -949,9 +977,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_test)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_test);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test);
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
if (is_logical_intrinsic_symbolic(x.m_test)) {
ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test);
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
}
}
} else if (ASR::is_a<ASR::LogicalBinOp_t>(*x.m_test)) {
ASR::LogicalBinOp_t* binop = ASR::down_cast<ASR::LogicalBinOp_t>(x.m_test);
Expand Down

0 comments on commit 71c95ce

Please sign in to comment.