From 390d34a4bb4e81e7865a57d34ee19d43714ce7d5 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 12 Nov 2023 13:30:56 +0530 Subject: [PATCH 1/3] Added support for Out[S] --- src/libasr/pass/pass_utils.h | 9 +++++---- src/libasr/pass/replace_symbolic.cpp | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 828d6033a9..af135698de 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -116,7 +116,8 @@ namespace LCompilers { static inline bool is_aggregate_or_array_type(ASR::expr_t* var) { return (ASR::is_a(*ASRUtils::expr_type(var)) || - ASRUtils::is_array(ASRUtils::expr_type(var))); + ASRUtils::is_array(ASRUtils::expr_type(var)) || + ASR::is_a(*ASRUtils::expr_type(var))); } template @@ -775,7 +776,7 @@ namespace LCompilers { } static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x, - bool (*is_array_or_struct)(ASR::expr_t*)) { + bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) { if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) { return; } @@ -787,7 +788,7 @@ namespace LCompilers { * in avoiding deep copies and the destination memory directly gets * filled inside the function. */ - if( is_array_or_struct(x->m_return_var)) { + if( is_array_or_struct_or_symbolic(x->m_return_var)) { for( auto& s_item: x->m_symtab->get_scope() ) { ASR::symbol_t* curr_sym = s_item.second; if( curr_sym->type == ASR::symbolType::Variable ) { @@ -826,7 +827,7 @@ namespace LCompilers { for (auto &item : x->m_symtab->get_scope()) { if (ASR::is_a(*item.second)) { handle_fn_return_var(al, ASR::down_cast( - item.second), is_array_or_struct); + item.second), is_array_or_struct_or_symbolic); } } } diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 4b856c2fbe..7abf80f8fd 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -123,7 +123,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor((ASR::asr_t*)&xx)); } if(xx.m_intent == ASR::intentType::In){ From 1c884963a49c6a7aa7d5f16d3d49a02104a4aff0 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sun, 12 Nov 2023 13:30:56 +0530 Subject: [PATCH 2/3] Added support for Out[S] --- src/libasr/pass/pass_utils.h | 9 +++++---- src/libasr/pass/replace_symbolic.cpp | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 828d6033a9..af135698de 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -116,7 +116,8 @@ namespace LCompilers { static inline bool is_aggregate_or_array_type(ASR::expr_t* var) { return (ASR::is_a(*ASRUtils::expr_type(var)) || - ASRUtils::is_array(ASRUtils::expr_type(var))); + ASRUtils::is_array(ASRUtils::expr_type(var)) || + ASR::is_a(*ASRUtils::expr_type(var))); } template @@ -775,7 +776,7 @@ namespace LCompilers { } static inline void handle_fn_return_var(Allocator &al, ASR::Function_t *x, - bool (*is_array_or_struct)(ASR::expr_t*)) { + bool (*is_array_or_struct_or_symbolic)(ASR::expr_t*)) { if (ASRUtils::get_FunctionType(x)->m_abi == ASR::abiType::BindPython) { return; } @@ -787,7 +788,7 @@ namespace LCompilers { * in avoiding deep copies and the destination memory directly gets * filled inside the function. */ - if( is_array_or_struct(x->m_return_var)) { + if( is_array_or_struct_or_symbolic(x->m_return_var)) { for( auto& s_item: x->m_symtab->get_scope() ) { ASR::symbol_t* curr_sym = s_item.second; if( curr_sym->type == ASR::symbolType::Variable ) { @@ -826,7 +827,7 @@ namespace LCompilers { for (auto &item : x->m_symtab->get_scope()) { if (ASR::is_a(*item.second)) { handle_fn_return_var(al, ASR::down_cast( - item.second), is_array_or_struct); + item.second), is_array_or_struct_or_symbolic); } } } diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 4b856c2fbe..7abf80f8fd 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -123,7 +123,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor((ASR::asr_t*)&xx)); } if(xx.m_intent == ASR::intentType::In){ From 1aa4522211b54898b0b77de77d5a51852824914b Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 13 Nov 2023 10:11:31 +0530 Subject: [PATCH 3/3] Added tests --- integration_tests/CMakeLists.txt | 1 + integration_tests/symbolics_13.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) create mode 100644 integration_tests/symbolics_13.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index f58a9841bd..1051514f43 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -716,6 +716,7 @@ RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST) RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_13 LABELS cpython_sym c_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_13.py b/integration_tests/symbolics_13.py new file mode 100644 index 0000000000..06f2c27599 --- /dev/null +++ b/integration_tests/symbolics_13.py @@ -0,0 +1,12 @@ +from lpython import S +from sympy import pi, Symbol + +def func() -> S: + return pi + +def test_func(): + z: S = func() + print(z) + assert z == pi + +test_func() \ No newline at end of file