Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify vecbasic_get to return SubroutineCall
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent 93546df commit 799932e
Showing 1 changed file with 54 additions and 61 deletions.
115 changes: 54 additions & 61 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,59 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
vecbasic_new_sym, vecbasic_new_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), nullptr, nullptr));
}

ASR::stmt_t* vecbasic_get(const Location& loc, ASR::expr_t *x, ASR::expr_t *y, ASR::expr_t *z) {
std::string name = "vecbasic_get";
symbolic_dependencies.push_back(name);
ASR::ttype_t *cptr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
ASR::symbol_t *vecbasic_get_sym = current_scope->resolve_symbol(name);
if ( !vecbasic_get_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 3);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, cptr_type,
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_Integer_t(al, loc, 4))),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, cptr_type,
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "z"), arg3);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
vecbasic_get_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, name),
dep.p, dep.n, args.p, args.n, body.p, body.n, nullptr,
ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, name), vecbasic_get_sym);
}
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 3);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = x;
call_args.push_back(al, call_arg);
call_arg.m_value = y;
call_args.push_back(al, call_arg);
call_arg.m_value = z;
call_args.push_back(al, call_arg);
return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym,
vecbasic_get_sym, call_args.p, call_args.n, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -641,7 +694,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicGetArgument: {
// Define necessary function symbols
ASR::expr_t* value1 = handle_argument(al, loc, x->m_args[0]);
ASR::symbol_t* vecbasic_get_sym = declare_vecbasic_get_function(al, loc, module_scope);
ASR::symbol_t* vecbasic_size_sym = declare_vecbasic_size_function(al, loc, module_scope);

// Define necessary variables
Expand Down Expand Up @@ -682,21 +734,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.push_back(al, stmt3);

// Statement 4
Vec<ASR::call_arg_t> call_args4;
call_args4.reserve(al, 3);
ASR::call_arg_t call_arg4, call_arg5, call_arg6;
call_arg4.loc = loc;
call_arg4.m_value = args;
call_arg5.loc = loc;
call_arg5.m_value = x->m_args[1];
call_arg6.loc = loc;
call_arg6.m_value = target;
call_args4.push_back(al, call_arg4);
call_args4.push_back(al, call_arg5);
call_args4.push_back(al, call_arg6);
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, vecbasic_get_sym,
vecbasic_get_sym, call_args4.p, call_args4.n, nullptr));
pass_result.push_back(al, stmt4);
pass_result.push_back(al, vecbasic_get(loc, args, x->m_args[1], target));
break;
}
default: {
Expand Down Expand Up @@ -863,51 +901,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_vecbasic_get_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "vecbasic_get";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 3);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE((ASR::make_Integer_t(al, loc, 4))),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "z"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "z"), arg3);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_vecbasic_size_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "vecbasic_size";
symbolic_dependencies.push_back(name);
Expand Down

0 comments on commit 799932e

Please sign in to comment.