Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR Pass] Symbolic: Simplify the symbolic handling #2431

Merged
merged 21 commits into from
Nov 26, 2023
Merged
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
dbb001f
[ASR Pass] Symbolic: Use a function to create `basic_new_stack` BindC…
Thirumalai-Shaktivel Nov 25, 2023
ed2997d
[ASR Pass] Symbolic: Use a function to create `basic_free_stack` Bind…
Thirumalai-Shaktivel Nov 25, 2023
2fe8807
[ASR Pass] Symbolic: Add `basic_free_stack` to function dependencies
Thirumalai-Shaktivel Nov 25, 2023
0c3d6f2
[ASR Pass] Symbolic: Simplify `basic_get_args` to return `SubroutineC…
Thirumalai-Shaktivel Nov 25, 2023
93546df
[ASR Pass] Symbolic: Simplify `vecbasic_new` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
799932e
[ASR Pass] Symbolic: Simplify `vecbasic_get` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
a7eae7b
[ASR Pass] Symbolic: Simplify `vecbasic_size` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
f497d07
[ASR Pass] Symbolic: Simplify `basic_assign` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
0aa4435
[ASR Pass] Symbolic: Simplify `basic_str` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
3096c28
[ASR Pass] Symbolic: Simplify `basic_get_type` to return `FunctionCall`
Thirumalai-Shaktivel Nov 25, 2023
bb48bdb
[ASR Pass] Symbolic: Simplify `basic_eq` & `basic_neq` into `basic_co…
Thirumalai-Shaktivel Nov 25, 2023
e8724c1
[ASR Pass] Symbolic: Simplify `integer_set_si` to return `SubroutineC…
Thirumalai-Shaktivel Nov 25, 2023
2745451
[ASR Pass] Symbolic: Simplify `symbol_set` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
9440150
[ASR Pass] Symbolic: Simplify `basic_const` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
3903658
[ASR Pass] Symbolic: Simplify `basic_binop` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
79066ee
[ASR Pass] Symbolic: Simplify `basic_unaryop` to return `SubroutineCall`
Thirumalai-Shaktivel Nov 25, 2023
d331f27
[ASR Pass] Symbolic: Simplify `process_intrinsic_function` arguments
Thirumalai-Shaktivel Nov 25, 2023
f18ae18
[ASR Pass] Symbolic: Simplify `process_intrinsic_function` to use macros
Thirumalai-Shaktivel Nov 25, 2023
760380a
[ASR Pass] Symbolic: Simplify `process_attributes` to use macros
Thirumalai-Shaktivel Nov 25, 2023
f6d0bd6
[ASR Pass] Symbolic: Simplify `basic_has_symbol` to return `FunctionC…
Thirumalai-Shaktivel Nov 25, 2023
6a7b2cd
[ASR Pass] Symbolic: Simplify `process_attributes` arguments
Thirumalai-Shaktivel Nov 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[ASR Pass] Symbolic: Simplify integer_set_si to return `SubroutineC…
…all`
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
commit e8724c15c21a711849359aba42e9457119be3717
131 changes: 50 additions & 81 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,53 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
basic_compare_sym, basic_compare_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
}

ASR::stmt_t* integer_set_si(const Location& loc, ASR::expr_t *target,
ASR::expr_t *value) {
std::string fn_name = "integer_set_si";
symbolic_dependencies.push_back(fn_name);
ASR::symbol_t *integer_set_si_sym = current_scope->resolve_symbol(fn_name);
if ( !integer_set_si_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 8)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
integer_set_si_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, fn_name), integer_set_si_sym);
}

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = target;
call_args.push_back(al, call_arg);
call_arg.m_value = value;
call_args.push_back(al, call_arg);

return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, integer_set_si_sym,
integer_set_si_sym, call_args.p, call_args.n, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -965,45 +1012,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}

ASR::symbol_t* declare_integer_set_si_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "integer_set_si";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 2);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg1);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1)));
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 8)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr,
SymbolTable* module_scope) {
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
Expand Down Expand Up @@ -1148,22 +1156,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* cast_arg = cast_t->m_arg;
ASR::expr_t* cast_value = cast_t->m_value;
if (ASR::is_a<ASR::Var_t>(*cast_arg)) {
ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, nullptr));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, integer_set_si(x.base.base.loc, x.m_target, value));
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(cast_value);
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
Expand All @@ -1180,24 +1176,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
const_value = const_int->m_n;
}

ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type))));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, integer_set_si(x.base.base.loc, x.m_target, value));
}
}
}
Expand Down Expand Up @@ -1416,7 +1399,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

void visit_Cast(const ASR::Cast_t &x) {
if(x.m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
SymbolTable* module_scope = current_scope->parent;

ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
std::string symengine_var = symengine_stack.push();
Expand Down Expand Up @@ -1451,24 +1433,11 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
const_value = const_int->m_n;
}

ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, const_value, cast_type))));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);

ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
pass_result.push_back(al, integer_set_si(x.base.base.loc, target, value));
}
}
}
Expand Down