diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 96a1469765..fadea2b021 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -266,6 +266,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*expr)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(expr); + int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ: + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ: + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ: + return true; + default: + return false; + } + } + return true; + } /********************************** Utils *********************************/ void visit_Function(const ASR::Function_t &x) { @@ -514,9 +533,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorm_type->type == ASR::ttypeType::SymbolicExpression) { process_intrinsic_function(x.base.base.loc, intrinsic_func, x.m_target); } else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value); - ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); - pass_result.push_back(al, stmt); + if (is_logical_intrinsic_symbolic(x.m_value)) { + ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); + pass_result.push_back(al, stmt); + } } } else if (ASR::is_a(*x.m_value)) { ASR::Cast_t* cast_t = ASR::down_cast(x.m_value); @@ -676,18 +697,22 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*xx.m_test)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(xx.m_test); if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test); - xx.m_test = function_call; + if (is_logical_intrinsic_symbolic(xx.m_test)) { + ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test); + xx.m_test = function_call; + } } } else if (ASR::is_a(*xx.m_test)) { ASR::LogicalNot_t* logical_not = ASR::down_cast(xx.m_test); if (ASR::is_a(*logical_not->m_arg)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(logical_not->m_arg); if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* function_call = process_attributes(xx.base.base.loc, logical_not->m_arg); - ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call, - logical_not->m_type, logical_not->m_value)); - xx.m_test = new_logical_not; + if (is_logical_intrinsic_symbolic(logical_not->m_arg)) { + ASR::expr_t* function_call = process_attributes(xx.base.base.loc, logical_not->m_arg); + ASR::expr_t* new_logical_not = ASRUtils::EXPR(ASR::make_LogicalNot_t(al, xx.base.base.loc, function_call, + logical_not->m_type, logical_not->m_value)); + xx.m_test = new_logical_not; + } } } } else if (ASR::is_a(*xx.m_test)) { @@ -784,8 +809,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*ASRUtils::expr_type(val))) { - ASR::expr_t* function_call = process_attributes(x.base.base.loc, val); - print_tmp.push_back(function_call); + if (is_logical_intrinsic_symbolic(val)) { + ASR::expr_t* function_call = process_attributes(x.base.base.loc, val); + print_tmp.push_back(function_call); + } } } else if (ASR::is_a(*val)) { ASR::Cast_t* cast_t = ASR::down_cast(val); @@ -926,14 +953,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { ASR::LogicalCompare_t *l = ASR::down_cast(x.m_test); + if (is_logical_intrinsic_symbolic(l->m_left) && is_logical_intrinsic_symbolic(l->m_right)) { + left_tmp = process_attributes(x.base.base.loc, l->m_left); + right_tmp = process_attributes(x.base.base.loc, l->m_right); + ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp, + l->m_op, right_tmp, l->m_type, l->m_value)); - left_tmp = process_attributes(x.base.base.loc, l->m_left); - right_tmp = process_attributes(x.base.base.loc, l->m_right); - ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp, - l->m_op, right_tmp, l->m_type, l->m_value)); - - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); - pass_result.push_back(al, assert_stmt); + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } } else if (ASR::is_a(*x.m_test)) { ASR::SymbolicCompare_t* s = ASR::down_cast(x.m_test); if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { @@ -949,9 +977,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_test); if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { - ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test); - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); - pass_result.push_back(al, assert_stmt); + if (is_logical_intrinsic_symbolic(x.m_test)) { + ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test); + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } } } else if (ASR::is_a(*x.m_test)) { ASR::LogicalBinOp_t* binop = ASR::down_cast(x.m_test);