Skip to content

Commit

Permalink
Merge pull request lcompilers#2331 from anutosh491/Fixing_symbolic_pa…
Browse files Browse the repository at this point in the history
…rameters

Added support for functions to accept symbolic variables
  • Loading branch information
certik authored Sep 24, 2023
2 parents 2293972 + be85f02 commit f9b09dd
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 113 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,7 @@ RUN(NAME symbolics_05 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym)
RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
17 changes: 17 additions & 0 deletions integration_tests/symbolics_09.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from sympy import Symbol, pi, S
from lpython import S, i32

def addInteger(x: S, y: S, z: S, i: i32):
_i: S = S(i)
print(x + y + z + _i)

def call_addInteger():
a: S = Symbol("x")
b: S = Symbol("y")
c: S = pi
addInteger(a, b, c, 2)

def main0():
call_addInteger()

main0()
260 changes: 147 additions & 113 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.reserve(al, 1);
}
std::vector<std::string> symbolic_dependencies;
std::set<ASR::symbol_t*> symbolic_vars;
std::set<ASR::symbol_t*> symbolic_vars_to_free;
std::set<ASR::symbol_t*> symbolic_vars_to_omit;
SymEngine_Stack symengine_stack;

void visit_Function(const ASR::Function_t &x) {
Expand All @@ -55,6 +56,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
SymbolTable* current_scope_copy = this->current_scope;
this->current_scope = xx.m_symtab;
SymbolTable* module_scope = this->current_scope->parent;

ASR::ttype_t* f_signature= xx.m_function_signature;
ASR::FunctionType_t *f_type = ASR::down_cast<ASR::FunctionType_t>(f_signature);
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
for (size_t i = 0; i < f_type->n_arg_types; ++i) {
if (f_type->m_arg_types[i]->type == ASR::ttypeType::SymbolicExpression) {
f_type->m_arg_types[i] = type1;
}
}

for (auto &item : x.m_symtab->get_scope()) {
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
Expand Down Expand Up @@ -83,7 +94,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
Vec<ASR::stmt_t*> func_body;
func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body);

for (ASR::symbol_t* symbol : symbolic_vars) {
for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue;
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
Expand All @@ -97,7 +109,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

xx.n_body = func_body.size();
xx.m_body = func_body.p;
symbolic_vars.clear();
symbolic_vars_to_free.clear();
}

void visit_Variable(const ASR::Variable_t& x) {
Expand All @@ -109,125 +121,130 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
xx.m_type = type1;
symbolic_vars.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));

ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8));
ASR::symbol_t* sym2 = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, xx.base.base.loc, current_scope,
s2c(al, placeholder), nullptr, 0,
xx.m_intent, nullptr,
nullptr, xx.m_storage,
type2, nullptr, xx.m_abi,
xx.m_access, xx.m_presence,
xx.m_value_attr));

current_scope->add_symbol(s2c(al, placeholder), sym2);

std::string new_name = "basic_new_stack";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
}
symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
if(xx.m_intent == ASR::intentType::In){
symbolic_vars_to_omit.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
}

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);
if(xx.m_intent == ASR::intentType::Local){
ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8));
ASR::symbol_t* sym2 = ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(al, xx.base.base.loc, current_scope,
s2c(al, placeholder), nullptr, 0,
xx.m_intent, nullptr,
nullptr, xx.m_storage,
type2, nullptr, xx.m_abi,
xx.m_access, xx.m_presence,
xx.m_value_attr));

Vec<char *> dep;
dep.reserve(al, 1);
current_scope->add_symbol(s2c(al, placeholder), sym2);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(new_name, new_symbol);
}
std::string new_name = "basic_new_stack";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
}

new_name = "basic_free_stack";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);
Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
Vec<char *> dep;
dep.reserve(al, 1);

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(new_name, new_symbol);
}

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);
new_name = "basic_free_stack";
symbolic_dependencies.push_back(new_name);
if (!module_scope->get_symbol(new_name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<char *> dep;
dep.reserve(al, 1);
Vec<ASR::expr_t*> args;
{
args.reserve(al, 1);
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
}

ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(new_name, new_symbol);
}
Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

ASR::symbol_t* var_sym = current_scope->get_symbol(var_name);
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym));

// statement 1
ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0,
ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4)))),
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2))));

// statement 2
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1));

// statement 3
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc,
target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr));
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
type1, nullptr));

// statement 4
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack");
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = xx.base.base.loc;
call_arg.m_value = target2;
call_args.push_back(al, call_arg);
Vec<char *> dep;
dep.reserve(al, 1);

// defining the assignment statement
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym,
basic_new_stack_sym, call_args.p, call_args.n, nullptr));

pass_result.push_back(al, stmt1);
pass_result.push_back(al, stmt2);
pass_result.push_back(al, stmt3);
pass_result.push_back(al, stmt4);
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
module_scope->add_symbol(new_name, new_symbol);
}

ASR::symbol_t* var_sym = current_scope->get_symbol(var_name);
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym));

// statement 1
ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0,
ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4)))),
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2,
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2))));

// statement 2
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1));

// statement 3
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc,
target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr));
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
type1, nullptr));

// statement 4
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack");
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg;
call_arg.loc = xx.base.base.loc;
call_arg.m_value = target2;
call_args.push_back(al, call_arg);

// defining the assignment statement
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym,
basic_new_stack_sym, call_args.p, call_args.n, nullptr));

pass_result.push_back(al, stmt1);
pass_result.push_back(al, stmt2);
pass_result.push_back(al, stmt3);
pass_result.push_back(al, stmt4);
}
}
}

Expand Down Expand Up @@ -621,7 +638,24 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
if (cast_t->m_kind == ASR::cast_kindType::IntegerToSymbolicExpression) {
ASR::expr_t* cast_arg = cast_t->m_arg;
ASR::expr_t* cast_value = cast_t->m_value;
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
if (ASR::is_a<ASR::Var_t>(*cast_arg)) {
ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, nullptr));
Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value;
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
integer_set_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(cast_value);
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
if (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id) ==
Expand Down Expand Up @@ -668,7 +702,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* val = x.m_values[i];
if (ASR::is_a<ASR::Var_t>(*val) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(val))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(val)->m_v;
if (symbolic_vars.find(v) == symbolic_vars.end()) return;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);

// Extract the symbol from value (Var)
Expand Down

0 comments on commit f9b09dd

Please sign in to comment.