Skip to content

Commit

Permalink
Merge pull request lcompilers#2374 from anutosh491/implementing_symbo…
Browse files Browse the repository at this point in the history
…lic_comparison

Added support for comparing symbolic expressions
  • Loading branch information
certik committed Oct 8, 2023
2 parents 670b12f + 4ec4a76 commit 36fe6cf
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 1 deletion.
19 changes: 18 additions & 1 deletion integration_tests/symbolics_02.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from sympy import Symbol
from sympy import Symbol, pi
from lpython import S

def test_symbolic_operations():
x: S = Symbol('x')
y: S = Symbol('y')
p1: S = pi
p2: S = pi

# Addition
z: S = x + y
Expand Down Expand Up @@ -37,4 +39,19 @@ def test_symbolic_operations():
assert(c == S(0))
print(c)

# Comparison
b1: bool = p1 == p2
print(b1)
assert(b1 == True)
b2: bool = p1 != pi
print(b2)
assert(b2 == False)
b3: bool = p1 != x
print(b3)
assert(b3 == True)
b4: bool = pi == Symbol("x")
print(b4)
assert(b4 == False)


test_symbolic_operations()
143 changes: 143 additions & 0 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,96 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_eq";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg3);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::symbol_t* declare_basic_neq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) {
std::string name = "basic_neq";
symbolic_dependencies.push_back(name);
if (!module_scope->get_symbol(name)) {
std::string header = "symengine/cwrapper.h";
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);

Vec<ASR::expr_t*> args;
args.reserve(al, 1);
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "x"), arg2);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
fn_symtab->add_symbol(s2c(al, "y"), arg3);
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));

Vec<ASR::stmt_t*> body;
body.reserve(al, 1);

Vec<char*> dep;
dep.reserve(al, 1);

ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
return_var, ASR::abiType::BindC, ASR::accessType::Public,
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
false, false, nullptr, 0, false, false, false, s2c(al, header));
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
module_scope->add_symbol(s2c(al, name), symbol);
}
return module_scope->get_symbol(name);
}

ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr,
SymbolTable* module_scope) {
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*expr)) {
Expand Down Expand Up @@ -772,6 +862,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}
}
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_value)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_value);
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
ASR::symbol_t* sym = nullptr;
if (s->m_op == ASR::cmpopType::Eq) {
sym = declare_basic_eq_function(al, x.base.base.loc, module_scope);
} else {
sym = declare_basic_neq_function(al, x.base.base.loc, module_scope);
}
ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left);
ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right);

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = value1;
call_args.push_back(al, call_arg1);
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value2;
call_args.push_back(al, call_arg2);

ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr));
pass_result.push_back(al, stmt);
}
}
}

Expand Down Expand Up @@ -905,6 +1022,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
print_tmp.push_back(function_call);
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*val)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(val);
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
ASR::symbol_t* sym = nullptr;
if (s->m_op == ASR::cmpopType::Eq) {
sym = declare_basic_eq_function(al, x.base.base.loc, module_scope);
} else {
sym = declare_basic_neq_function(al, x.base.base.loc, module_scope);
}
ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left);
ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right);

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 1);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = value1;
call_args.push_back(al, call_arg1);
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = value2;
call_args.push_back(al, call_arg2);

ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr));
print_tmp.push_back(function_call);
}
} else {
print_tmp.push_back(x.m_values[i]);
}
Expand Down

0 comments on commit 36fe6cf

Please sign in to comment.