Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify basic_const to return SubroutineCall
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent 2745451 commit 9440150
Showing 1 changed file with 40 additions and 51 deletions.
91 changes: 40 additions & 51 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,44 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, symbol_set_sym,
symbol_set_sym, call_args.p, call_args.n, nullptr));
}

ASR::stmt_t *basic_const(const Location &loc,
const std::string &fn_name, ASR::expr_t* value) {
symbolic_dependencies.push_back(fn_name);
ASR::symbol_t *basic_const_sym = current_scope->resolve_symbol(fn_name);
if ( !basic_const_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 1);
ASR::symbol_t* arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
basic_const_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, fn_name), basic_const_sym);
}

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value;
call_args.push_back(al, call_arg);

return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc,
basic_const_sym, basic_const_sym, call_args.p, call_args.n, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -824,50 +862,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.push_back(al, stmt);
}

void perform_symbolic_constant_operation(Allocator &al, const Location &loc, SymbolTable* module_scope,
const std::string& new_name, ASR::expr_t* value) {
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(s2c(al, new_name), new_symbol);
}

ASR::symbol_t* func_sym = module_scope->get_symbol(new_name);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value;
call_args.push_back(al, call_arg);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, func_sym,
func_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
}

ASR::expr_t* handle_argument(Allocator &al, const Location &loc, ASR::expr_t* arg) {
if (ASR::is_a<ASR::Var_t>(*arg)) {
return arg;
Expand Down Expand Up @@ -895,11 +889,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
perform_symbolic_unary_operation(al, loc, module_scope, new_name, target, value1);
}

void process_constants(Allocator &al, const Location &loc, ASR::IntrinsicScalarFunction_t* /*x*/, SymbolTable* module_scope,
const std::string& new_name, ASR::expr_t* target) {
perform_symbolic_constant_operation(al, loc, module_scope, new_name, target);
}

void process_intrinsic_function(Allocator &al, const Location &loc, ASR::IntrinsicScalarFunction_t* x, SymbolTable* module_scope,
ASR::expr_t* target){
int64_t intrinsic_id = x->m_intrinsic_id;
Expand All @@ -909,11 +898,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPi: {
process_constants(al, loc, x, module_scope, "basic_const_pi", target);
pass_result.push_back(al, basic_const(loc, "basic_const_pi", target));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicE: {
process_constants(al, loc, x, module_scope, "basic_const_E", target);
pass_result.push_back(al, basic_const(loc, "basic_const_E", target));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAdd: {
Expand Down

0 comments on commit 9440150

Please sign in to comment.