Skip to content

Commit

Permalink
Merge pull request lcompilers#2569 from anutosh491/Fixing_returning_v…
Browse files Browse the repository at this point in the history
…ariables

Fixing issues with freeing variables
  • Loading branch information
certik committed Mar 6, 2024
2 parents ae1bcd5 + 9a07c85 commit 55c145b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ RUN(NAME test_gruntz LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_15 LABELS c_sym llvm_sym NOFAST)
RUN(NAME symbolics_16 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_17 LABELS cpython_sym c_sym llvm_sym NOFAST)
RUN(NAME symbolics_18 LABELS cpython_sym c_sym llvm_sym NOFAST)

RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
Expand Down
36 changes: 36 additions & 0 deletions integration_tests/symbolics_18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from lpython import S
from sympy import Symbol, log

def func_01(e: S, x: S) -> S:
print(e)
if e == x:
return x
print(e)
return e

def test_func_01():
x: S = Symbol("x")
ans: S = func_01(log(x), x)
print(ans)

def func_02(e: S, x: S) -> list[S]:
print(e)
if e == x:
list1: list[S] = [x]
return list1
else:
print(e)
list2: list[S] = func_02(x, x)
return list2

def test_func_02():
x: S = Symbol("x")
ans: list[S] = func_02(log(x), x)
ele: S = ans[0]
print(ele)

def tests():
test_func_01()
test_func_02()

tests()
10 changes: 0 additions & 10 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -830,16 +830,6 @@ namespace LCompilers {
s_func_type->m_arg_types = arg_types.p;
s_func_type->n_arg_types = arg_types.n;
s_func_type->m_return_var_type = nullptr;

if (ASR::is_a<ASR::Return_t>(*x->m_body[x->n_body - 1])) {
Vec<ASR::stmt_t*> func_body;
func_body.reserve(al, x->n_body - 1);
for (size_t i=0; i< x->n_body - 1; i++) {
func_body.push_back(al, x->m_body[i]);
}
x->m_body = func_body.p;
x->n_body = func_body.n;
}
}
}
for (auto &item : x->m_symtab->get_scope()) {
Expand Down
13 changes: 6 additions & 7 deletions src/libasr/pass/replace_symbolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body);

for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue;
func_body.push_back(al, basic_free_stack(x.base.base.loc,
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol))));
}
Expand Down Expand Up @@ -352,7 +351,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi

ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
xx.m_type = CPtr_type;
if (var_name != "_lpython_return_variable" && xx.m_intent != ASR::intentType::Out) {
if (xx.m_intent == ASR::intentType::Local) {
symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
}
if(xx.m_intent == ASR::intentType::In){
Expand Down Expand Up @@ -524,7 +523,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
void visit_Assignment(const ASR::Assignment_t &x) {
if (ASR::is_a<ASR::Var_t>(*x.m_value) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(x.m_value))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
if ((symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) &&
(symbolic_vars_to_omit.find(v) == symbolic_vars_to_omit.end())) return;
ASR::symbol_t* var_sym = ASR::down_cast<ASR::Var_t>(x.m_value)->m_v;
pass_result.push_back(al, basic_assign(x.base.base.loc, x.m_target,
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym))));
Expand Down Expand Up @@ -784,7 +784,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
ASR::expr_t* val = x.m_values[i];
if (ASR::is_a<ASR::Var_t>(*val) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(val))) {
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(val)->m_v;
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
if ((symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) &&
(symbolic_vars_to_omit.find(v) == symbolic_vars_to_omit.end())) return;
print_tmp.push_back(basic_str(x.base.base.loc, val));
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val)) {
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
Expand Down Expand Up @@ -1007,14 +1008,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
}

void visit_Return(const ASR::Return_t &x) {
// freeing out variables
if (!symbolic_vars_to_free.empty()){
for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue;
// freeing out variables
pass_result.push_back(al, basic_free_stack(x.base.base.loc,
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol))));
}
symbolic_vars_to_free.clear();
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Return_t(al, x.base.base.loc)));
}
}
Expand Down

0 comments on commit 55c145b

Please sign in to comment.