Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Use a function to create basic_new_stack BindC…
Browse files Browse the repository at this point in the history
… Function
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent 7899add commit dbb001f
Showing 1 changed file with 45 additions and 43 deletions.
88 changes: 45 additions & 43 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,48 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
std::set<ASR::symbol_t*> symbolic_vars_to_omit;
SymEngine_Stack symengine_stack;

/********************************** Utils *********************************/
ASR::stmt_t *basic_new_stack(const Location &loc, ASR::expr_t *x) {
std::string fn_name = "basic_new_stack";
symbolic_dependencies.push_back(fn_name);
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
ASR::symbol_t* basic_new_stack_sym = current_scope->resolve_symbol(fn_name);
if ( !basic_new_stack_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg)));
}

Vec<ASR::stmt_t *> body; body.reserve(al, 1);
Vec<char *> dependencies; dependencies.reserve(al, 1);
basic_new_stack_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dependencies.p, dependencies.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false,
false, false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(fn_name, basic_new_stack_sym);
}

Vec<ASR::call_arg_t> call_args; call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = x;
call_args.push_back(al, call_arg);
return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_new_stack_sym,
basic_new_stack_sym, call_args.p, call_args.n, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
// FIXME: this is a hack, we need to pass in a non-const `x`,
// which requires to generate a TransformVisitor.
Expand Down Expand Up @@ -143,39 +185,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

current_scope->add_symbol(s2c(al, placeholder), sym2);

std::string new_name = "basic_new_stack";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
}

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char *> dep;
dep.reserve(al, 1);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(new_name, new_symbol);
}

new_name = "basic_free_stack";
std::string new_name = "basic_free_stack";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
Expand Down Expand Up @@ -228,21 +239,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
type1, nullptr));

// statement 4
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack");
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = xx.base.base.loc;
call_arg.m_value = target2;
call_args.push_back(al, call_arg);

// defining the assignment statement
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym,
basic_new_stack_sym, call_args.p, call_args.n, nullptr));
// statement 4
ASR::stmt_t* stmt4 = basic_new_stack(x.base.base.loc, target2);

pass_result.push_back(al, stmt1);
pass_result.push_back(al, stmt2);
Expand Down

0 comments on commit dbb001f

Please sign in to comment.