Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify process_attributes arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent f6d0bd6 commit 6a7b2cd
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,8 +1003,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr,
SymbolTable* module_scope) {
ASR::expr_t* process_attributes(const Location &loc, ASR::expr_t* expr) {
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(expr);
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
Expand All @@ -1031,7 +1030,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}

void visit_Assignment(const ASR::Assignment_t &x) {
SymbolTable* module_scope = current_scope->parent;
if (ASR::is_a<ASR::Var_t>(*x.m_value) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(x.m_value))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
Expand All @@ -1043,7 +1041,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) {
process_intrinsic_function(x.base.base.loc, intrinsic_func, x.m_target);
} else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope);
ASR::expr_t* function_call = process_attributes(x.base.base.loc, x.m_value);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
pass_result.push_back(al, stmt);
}
Expand Down Expand Up @@ -1129,11 +1127,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::If_t& xx = const_cast<ASR::If_t&>(x);
transform_stmts(xx.m_body, xx.n_body);
transform_stmts(xx.m_orelse, xx.n_orelse);
SymbolTable* module_scope = current_scope->parent;
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* function_call = process_attributes(al, xx.base.base.loc, xx.m_test, module_scope);
ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test);
xx.m_test = function_call;
}
}
Expand Down Expand Up @@ -1190,7 +1187,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

void visit_Print(const ASR::Print_t &x) {
std::vector<ASR::expr_t*> print_tmp;
SymbolTable* module_scope = current_scope->parent;
for (size_t i=0; i<x.n_values; i++) {
ASR::expr_t* val = x.m_values[i];
if (ASR::is_a<ASR::Var_t>(*val) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(val))) {
Expand Down Expand Up @@ -1220,7 +1216,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
// Now create the FunctionCall node for basic_str
print_tmp.push_back(basic_str(x.base.base.loc, target));
} else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type(val))) {
ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope);
ASR::expr_t* function_call = process_attributes(x.base.base.loc, val);
print_tmp.push_back(function_call);
}
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
Expand Down Expand Up @@ -1358,14 +1354,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}

void visit_Assert(const ASR::Assert_t &x) {
SymbolTable* module_scope = current_scope->parent;
ASR::expr_t* left_tmp = nullptr;
ASR::expr_t* right_tmp = nullptr;
if (ASR::is_a<ASR::LogicalCompare_t>(*x.m_test)) {
ASR::LogicalCompare_t *l = ASR::down_cast<ASR::LogicalCompare_t>(x.m_test);

left_tmp = process_attributes(al, x.base.base.loc, l->m_left, module_scope);
right_tmp = process_attributes(al, x.base.base.loc, l->m_right, module_scope);
left_tmp = process_attributes(x.base.base.loc, l->m_left);
right_tmp = process_attributes(x.base.base.loc, l->m_right);
ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp,
l->m_op, right_tmp, l->m_type, l->m_value));

Expand All @@ -1386,7 +1381,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*x.m_test)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(x.m_test);
if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) {
ASR::expr_t* test = process_attributes(al, x.base.base.loc, x.m_test, module_scope);
ASR::expr_t* test = process_attributes(x.base.base.loc, x.m_test);
ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg));
pass_result.push_back(al, assert_stmt);
}
Expand Down

0 comments on commit 6a7b2cd

Please sign in to comment.