Skip to content

Commit

Permalink
Adding Support for symbolics in the list data structure (lcompilers#2368
Browse files Browse the repository at this point in the history
)
  • Loading branch information
anutosh491 authored Oct 10, 2023
1 parent ac65227 commit 9b5f52f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym)
RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
18 changes: 18 additions & 0 deletions integration_tests/symbolics_11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from sympy import Symbol, sin, pi
from lpython import S

def test_extraction_of_elements():
x: S = Symbol("x")
l1: list[S] = [x, pi, sin(x), Symbol("y")]
ele1: S = l1[0]
ele2: S = l1[1]
ele3: S = l1[2]
ele4: S = l1[3]

assert(ele1 == x)
assert(ele2 == pi)
assert(ele3 == sin(x))
assert(ele4 == Symbol("y"))
print(ele1, ele2, ele3, ele4)

test_extraction_of_elements()
48 changes: 48 additions & 0 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
pass_result.push_back(al, stmt3);
pass_result.push_back(al, stmt4);
}
} else if (xx.m_type->type == ASR::ttypeType::List) {
ASR::List_t* list = ASR::down_cast<ASR::List_t>(xx.m_type);
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type));
xx.m_type = list_type;
}
}
}

Expand Down Expand Up @@ -920,6 +927,47 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}
}
}
} else if (ASR::is_a<ASR::ListConstant_t>(*x.m_value)) {
ASR::ListConstant_t* list_constant = ASR::down_cast<ASR::ListConstant_t>(x.m_value);
if (list_constant->m_type->type == ASR::ttypeType::List) {
ASR::List_t* list = ASR::down_cast<ASR::List_t>(list_constant->m_type);
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
Vec<ASR::expr_t*> temp_list;
temp_list.reserve(al, list_constant->n_args + 1);

for (size_t i = 0; i < list_constant->n_args; ++i) {
ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]);
temp_list.push_back(al, value);
}

ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, type));
ASR::expr_t* temp_list_const = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list.p,
temp_list.size(), list_type));
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const, nullptr));
pass_result.push_back(al, stmt);
}
}
} else if (ASR::is_a<ASR::ListItem_t>(*x.m_value)) {
ASR::ListItem_t* list_item = ASR::down_cast<ASR::ListItem_t>(x.m_value);
if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) {
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope);

Vec<ASR::call_arg_t> call_args;
call_args.reserve(al, 2);
ASR::call_arg_t call_arg1, call_arg2;
call_arg1.loc = x.base.base.loc;
call_arg1.m_value = x.m_target;
call_arg2.loc = x.base.base.loc;
call_arg2.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a,
list_item->m_pos, CPtr_type, nullptr));
call_args.push_back(al, call_arg1);
call_args.push_back(al, call_arg2);
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym,
basic_assign_sym, call_args.p, call_args.n, nullptr));
pass_result.push_back(al, stmt);
}
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_value)) {
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_value);
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {
Expand Down

0 comments on commit 9b5f52f

Please sign in to comment.