Skip to content

Commit

Permalink
[ASR Pass] Symbolic: Simplify basic_get_type to return FunctionCall
Browse files Browse the repository at this point in the history
  • Loading branch information
Thirumalai-Shaktivel committed Nov 25, 2023
1 parent 0aa4435 commit 3096c28
Showing 1 changed file with 48 additions and 89 deletions.
137 changes: 48 additions & 89 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,49 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, loc, 1, -2, nullptr)), nullptr, nullptr));
}

ASR::expr_t* basic_get_type(const Location& loc, ASR::expr_t* value) {
std::string fn_name = "basic_get_type";
symbolic_dependencies.push_back(fn_name);
ASR::symbol_t *basic_get_type_sym = current_scope->resolve_symbol(fn_name);
if ( !basic_get_type_sym ) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);

Vec<ASR::expr_t*> args; args.reserve(al, 1);
char *return_var_name =s2c(al, "_lpython_return_variable");
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, return_var_name, nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(return_var_name, arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body; body.reserve(al, 1);
Vec<char*> dep; dep.reserve(al, 1);
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1));
basic_get_type_sym = ASR::down_cast<ASR::symbol_t>(
ASRUtils::make_Function_t_util(al, loc, fn_symtab, s2c(al, fn_name),
dep.p, dep.n, args.p, args.n, body.p, body.n, return_var,
ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header)));
current_scope->parent->add_symbol(s2c(al, fn_name), basic_get_type_sym);
}
Vec<ASR::call_arg_t> call_args; call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value;
call_args.push_back(al, call_arg);
return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
}
/********************************** Utils *********************************/

void visit_Function(const ASR::Function_t &x) {
Expand Down Expand Up @@ -910,45 +953,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_get_type_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_get_type";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_eq";
symbolic_dependencies.push_back(name);
Expand Down Expand Up @@ -1106,89 +1110,44 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 16, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 15, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 17, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicLogQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 29 as the right value of the IntegerCompare node as it represents SYMENGINE_LOG through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 29, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr));
break;
}
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicSinQ: {
ASR::symbol_t* basic_get_type_sym = declare_basic_get_type_function(al, loc, module_scope);
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = loc;
call_arg.m_value = value1;
call_args.push_back(al, call_arg);
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
basic_get_type_sym, basic_get_type_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)), nullptr, nullptr));
ASR::expr_t* function_call = basic_get_type(loc, value1);
// Using 35 as the right value of the IntegerCompare node as it represents SYMENGINE_SIN through SYMENGINE_ENUM
return ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, loc, function_call, ASR::cmpopType::Eq,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 35, ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)))),
Expand Down

0 comments on commit 3096c28

Please sign in to comment.