Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify basic_str to return FunctionCall
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent f497d07 commit 0aa4435
Showing 1 changed file with 61 additions and 110 deletions.
171 changes: 61 additions & 110 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,51 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_assign_sym,
basic_assign_sym, call_args.p, call_args.n, nullptr));
}

ASR::expr_t* basic_str(const Location& loc, ASR::expr_t *x) {
std::string fn_name = "basic_str";
symbolic_dependencies.push_back(fn_name);
ASR::symbol_t *basic_str_sym = current_scope->resolve_symbol(fn_name);
if ( !basic_str_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 1);
char *return_var_name = s2c(al, "_lpython_return_variable");
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, return_var_name, nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default,
ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(return_var_name, arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default,
ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1));
basic_str_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dep.p, dep.n, args.p, args.n, body.p, body.n, return_var,
ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, fn_name), basic_str_sym);
}
Vec<ASR::call_arg_t> call_args; call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = x;
call_args.push_back(al, call_arg);
return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -826,45 +871,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::symbol_t* declare_basic_str_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_str";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_integer_set_si_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "integer_set_si";
symbolic_dependencies.push_back(name);
Expand Down Expand Up @@ -1406,23 +1412,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
if (ASR::is_a<ASR::Var_t>(*val) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(val))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(val)->m_v;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);

// Extract the symbol from value (Var)
ASR::symbol_t* var_sym = ASR::down_cast<ASR::Var_t>(val)->m_v;
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

// Now create the FunctionCall node for basic_str
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
print_tmp.push_back(function_call);
print_tmp.push_back(basic_str(x.base.base.loc, val));
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
if (ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(val))) {
Expand All @@ -1444,17 +1434,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);

// Now create the FunctionCall node for basic_str
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
print_tmp.push_back(function_call);
print_tmp.push_back(basic_str(x.base.base.loc, target));
} else if (ASR::is_a<ASR::Logical_t>(*ASRUtils::expr_type(val))) {
ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope);
print_tmp.push_back(function_call);
Expand All @@ -1467,17 +1447,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));

// Now create the FunctionCall node for basic_str
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
print_tmp.push_back(function_call);
print_tmp.push_back(basic_str(x.base.base.loc, target));
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*val)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(val);
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
Expand Down Expand Up @@ -1507,20 +1477,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
} else if (ASR::is_a<ASR::ListItem_t>(*val)) {
ASR::ListItem_t* list_item = ASR::down_cast<ASR::ListItem_t>(val);
if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) {
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = x.base.base.loc;
call_arg.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a,
list_item->m_pos, CPtr_type, nullptr));
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
print_tmp.push_back(function_call);
ASR::expr_t *value = ASRUtils::EXPR(ASR::make_ListItem_t(al,
x.base.base.loc, list_item->m_a, list_item->m_pos,
ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc)), nullptr));
print_tmp.push_back(basic_str(x.base.base.loc, value));
}
} else {
print_tmp.push_back(x.m_values[i]);
Expand Down Expand Up @@ -1623,8 +1583,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::expr_t* process_with_basic_str(Allocator &al, const Location &loc, const ASR::expr_t* expr,
ASR::symbol_t* basic_str_sym) {
ASR::expr_t* process_with_basic_str(const Location &loc, const ASR::expr_t *expr) {
ASR::symbol_t *var_sym = nullptr;
if (ASR::is_a<ASR::Var_t>(*expr)) {
var_sym = ASR::down_cast<ASR::Var_t>(expr)->m_v;
Expand All @@ -1636,20 +1595,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(expr);
this->visit_Cast(*cast_t);
var_sym = current_scope->get_symbol(symengine_stack.pop());
} else {
LCOMPILERS_ASSERT(false);
}

ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, loc, var_sym));
// Now create the FunctionCall node for basic_str
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr));
return function_call;
// Now create the FunctionCall node for basic_str and return
return basic_str(loc, target);
}

void visit_Assert(const ASR::Assert_t &x) {
Expand Down Expand Up @@ -1703,16 +1655,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
} else if (ASR::is_a<ASR::LogicalBinOp_t>(*x.m_test)) {
ASR::LogicalBinOp_t* binop = ASR::down_cast<ASR::LogicalBinOp_t>(x.m_test);
if (ASR::is_a<ASR::SymbolicCompare_t>(*binop->m_left) && ASR::is_a<ASR::SymbolicCompare_t>(*binop->m_right)) {
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
ASR::SymbolicCompare_t *s1 = ASR::down_cast<ASR::SymbolicCompare_t>(binop->m_left);
left_tmp = process_with_basic_str(al, x.base.base.loc, s1->m_left, basic_str_sym);
right_tmp = process_with_basic_str(al, x.base.base.loc, s1->m_right, basic_str_sym);
left_tmp = process_with_basic_str(x.base.base.loc, s1->m_left);
right_tmp = process_with_basic_str(x.base.base.loc, s1->m_right);
ASR::expr_t* test1 = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
s1->m_op, right_tmp, s1->m_type, s1->m_value));

ASR::SymbolicCompare_t *s2 = ASR::down_cast<ASR::SymbolicCompare_t>(binop->m_right);
left_tmp = process_with_basic_str(al, x.base.base.loc, s2->m_left, basic_str_sym);
right_tmp = process_with_basic_str(al, x.base.base.loc, s2->m_right, basic_str_sym);
left_tmp = process_with_basic_str(x.base.base.loc, s2->m_left);
right_tmp = process_with_basic_str(x.base.base.loc, s2->m_right);
ASR::expr_t* test2 = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp,
s2->m_op, right_tmp, s2->m_type, s2->m_value));

Expand Down

0 comments on commit 0aa4435

Please sign in to comment.