diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 1177bb5266..44b703eaeb 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -729,6 +729,7 @@ RUN(NAME test_gruntz LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_15 LABELS c_sym llvm_sym NOFAST) RUN(NAME symbolics_16 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_17 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_18 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_18.py b/integration_tests/symbolics_18.py new file mode 100644 index 0000000000..b3dd8bad2c --- /dev/null +++ b/integration_tests/symbolics_18.py @@ -0,0 +1,36 @@ +from lpython import S +from sympy import Symbol, log + +def func_01(e: S, x: S) -> S: + print(e) + if e == x: + return x + print(e) + return e + +def test_func_01(): + x: S = Symbol("x") + ans: S = func_01(log(x), x) + print(ans) + +def func_02(e: S, x: S) -> list[S]: + print(e) + if e == x: + list1: list[S] = [x] + return list1 + else: + print(e) + list2: list[S] = func_02(x, x) + return list2 + +def test_func_02(): + x: S = Symbol("x") + ans: list[S] = func_02(log(x), x) + ele: S = ans[0] + print(ele) + +def tests(): + test_func_01() + test_func_02() + +tests() \ No newline at end of file diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 2a34ef0f15..26546880f2 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -830,16 +830,6 @@ namespace LCompilers { s_func_type->m_arg_types = arg_types.p; s_func_type->n_arg_types = arg_types.n; s_func_type->m_return_var_type = nullptr; - - if (ASR::is_a(*x->m_body[x->n_body - 1])) { - Vec func_body; - func_body.reserve(al, x->n_body - 1); - for (size_t i=0; i< x->n_body - 1; i++) { - func_body.push_back(al, x->m_body[i]); - } - x->m_body = func_body.p; - x->n_body = func_body.n; - } } } for (auto &item : x->m_symtab->get_scope()) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index fadea2b021..d17b575b21 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -323,7 +323,6 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor((ASR::asr_t*)&xx)); } if(xx.m_intent == ASR::intentType::In){ @@ -524,7 +523,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value) && ASR::is_a(*ASRUtils::expr_type(x.m_value))) { ASR::symbol_t *v = ASR::down_cast(x.m_value)->m_v; - if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return; + if ((symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) && + (symbolic_vars_to_omit.find(v) == symbolic_vars_to_omit.end())) return; ASR::symbol_t* var_sym = ASR::down_cast(x.m_value)->m_v; pass_result.push_back(al, basic_assign(x.base.base.loc, x.m_target, ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)))); @@ -784,7 +784,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { ASR::symbol_t *v = ASR::down_cast(val)->m_v; - if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return; + if ((symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) && + (symbolic_vars_to_omit.find(v) == symbolic_vars_to_omit.end())) return; print_tmp.push_back(basic_str(x.base.base.loc, val)); } else if (ASR::is_a(*val)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(val); @@ -1007,14 +1008,12 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor