Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify integer_set_si to return `SubroutineC…
Browse files Browse the repository at this point in the history
…all`
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent bb48bdb commit e8724c1
Showing 1 changed file with 50 additions and 81 deletions.
131 changes: 50 additions & 81 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,53 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
basic_compare_sym, basic_compare_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
}

ASR::stmt_t* integer_set_si(const Location& loc, ASR::expr_t *target,
ASR::expr_t *value) {
std::string fn_name = "integer_set_si";
symbolic_dependencies.push_back(fn_name);
ASR::symbol_t *integer_set_si_sym = current_scope->resolve_symbol(fn_name);
if ( !integer_set_si_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 8)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
integer_set_si_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, fn_name), integer_set_si_sym);
}

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
call_arg.m_value = value;
call_args.push_back(al, call_arg);

return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, integer_set_si_sym,
integer_set_si_sym, call_args.p, call_args.n, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -965,45 +1012,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::symbol_t* declare_integer_set_si_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "integer_set_si";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 8)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr,
SymbolTable* module_scope) {
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
Expand Down Expand Up @@ -1148,22 +1156,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* cast_arg = cast_t->m_arg;
ASR::expr_t* cast_value = cast_t->m_value;
if (ASR::is_a<ASR::Var_t>(*cast_arg)) {
ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, nullptr));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, integer_set_si(x.base.base.loc, x.m_target, value));
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(cast_value);
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
Expand All @@ -1180,24 +1176,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
const_value = const_int->m_n;
}

ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type))));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, integer_set_si(x.base.base.loc, x.m_target, value));
}
}
}
Expand Down Expand Up @@ -1416,7 +1399,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

void visit_Cast(const ASR::Cast_t &x) {
if(x.m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
SymbolTable* module_scope = current_scope->parent;

ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
std::string symengine_var = symengine_stack.push();
Expand Down Expand Up @@ -1451,24 +1433,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
const_value = const_int->m_n;
}

ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type))));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, integer_set_si(x.base.base.loc, target, value));
}
}
}
Expand Down

0 comments on commit e8724c1

Please sign in to comment.