diff --git a/.gitignore b/.gitignore index 89ae329fe8..0e0c4e6d87 100644 --- a/.gitignore +++ b/.gitignore @@ -188,6 +188,8 @@ integration_tests/structs_02 integration_tests/structs_02.c integration_tests/structs_03 integration_tests/structs_03.c +integration_tests/structs_05 +integration_tests/structs_05.c integration_tests/expr_08 integration_tests/expr_08.c integration_tests/expr_12 diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 339e9d4aa4..19459b94e8 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -184,6 +184,7 @@ RUN(NAME structs_01 LABELS cpython llvm c) RUN(NAME structs_02 LABELS llvm c) RUN(NAME structs_03 LABELS llvm c) RUN(NAME structs_04 LABELS cpython llvm c) +RUN(NAME structs_05 LABELS llvm c) RUN(NAME test_str_to_int LABELS cpython llvm) RUN(NAME test_platform LABELS cpython llvm) RUN(NAME test_vars_01 LABELS cpython llvm) diff --git a/integration_tests/structs_05.py b/integration_tests/structs_05.py new file mode 100644 index 0000000000..e5ec3d4dfe --- /dev/null +++ b/integration_tests/structs_05.py @@ -0,0 +1,36 @@ +from ltypes import i32, f64, dataclass + +@dataclass +class A: + y: f64 + x: i32 + +def verify(s: A[:], x1: i32, y1: f64, x2: i32, y2: f64): + eps: f64 = 1e-12 + print(s[0].x, s[0].y) + assert s[0].x == x1 + assert abs(s[0].y - y1) < eps + print(s[1].x, s[1].y) + assert s[1].x == x2 + assert abs(s[1].y - y2) < eps + +def update_1(s: A): + s.x = 2 + s.y = 1.2 + +def update_2(s: A[:]): + s[1].x = 3 + s[1].y = 2.3 + +def g(): + # TODO: Replace y: A[2] with y: A[2] = [None, None] + # TODO: And enable cpython in integration_tests. + y: A[2] + y[0] = A(1.1, 1) + y[1] = A(2.2, 2) + verify(y, 1, 1.1, 2, 2.2) + update_1(y[0]) + update_2(y) + verify(y, 2, 1.2, 3, 2.3) + +g() diff --git a/src/libasr/codegen/asr_to_c.cpp b/src/libasr/codegen/asr_to_c.cpp index 28b9cfd937..a207ac1008 100644 --- a/src/libasr/codegen/asr_to_c.cpp +++ b/src/libasr/codegen/asr_to_c.cpp @@ -157,7 +157,7 @@ class ASRToCVisitor : public BaseCCPPVisitor ASR::Derived_t *t = ASR::down_cast(v.m_type); std::string der_type_name = ASRUtils::symbol_name(t->m_derived_type); std::string dims = convert_dims_c(t->n_dims, t->m_dims); - if( v.m_intent == ASRUtils::intent_local && pre_initialise_derived_type ) { + if( v.m_intent == ASRUtils::intent_local && pre_initialise_derived_type) { std::string value_var_name = v.m_parent_symtab->get_unique_name(std::string(v.m_name) + "_value"); sub = format_type_c(dims, "struct " + der_type_name, value_var_name, use_ref, dummy); @@ -168,9 +168,18 @@ class ASRToCVisitor : public BaseCCPPVisitor } sub += ";\n"; sub += indent + format_type_c("", "struct " + der_type_name + "*", v.m_name, use_ref, dummy); - sub += "= &" + value_var_name; + if( t->n_dims != 0 ) { + sub += " = " + value_var_name; + } else { + sub += " = &" + value_var_name; + } return sub; } else { + if( v.m_intent == ASRUtils::intent_in || + v.m_intent == ASRUtils::intent_inout ) { + use_ref = false; + dims = ""; + } sub = format_type_c(dims, "struct " + der_type_name + "*", v.m_name, use_ref, dummy); } diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 6f5b20e37d..24ca42f751 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -592,12 +592,20 @@ R"(#include this->visit_expr(*x.m_v); der_expr = std::move(src); member = ASRUtils::symbol_name(x.m_m); - src = der_expr + "->" + member; + if( ASR::is_a(*x.m_v) ) { + src = der_expr + "." + member; + } else { + src = der_expr + "->" + member; + } } void visit_ArrayItem(const ASR::ArrayItem_t &x) { const ASR::symbol_t *s = ASRUtils::symbol_get_past_external(x.m_v); ASR::Variable_t* sv = ASR::down_cast(s); + std::string prefix = ""; + // if( ASR::is_a(*sv->m_type) ) { + // prefix = "&"; + // } std::string out = std::string(sv->m_name); if( (sv->m_intent == ASRUtils::intent_in || sv->m_intent == ASRUtils::intent_inout) && @@ -619,6 +627,11 @@ R"(#include } out += "]"; last_expr_precedence = 2; + // if( !prefix.empty() ) { + // src = prefix + "(" + out + ")"; + // } else { + // src = out; + // } src = out; } @@ -1178,7 +1191,12 @@ R"(#include } } else { self().visit_expr(*x.m_args[i].m_value); - out += src; + if( ASR::is_a(*x.m_args[i].m_value) && + ASR::is_a(*ASRUtils::expr_type(x.m_args[i].m_value)) ) { + out += "&" + src; + } else { + out += src; + } } if (i < x.n_args-1) out += ", "; } diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index ea535397ff..74d8cd68d0 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1155,6 +1155,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor return; } ASR::Variable_t *v = ASR::down_cast(x.m_v); + if( ASR::is_a(*v->m_type) ) { + ASR::Derived_t* der_type = ASR::down_cast(v->m_type); + der_type_name = ASRUtils::symbol_name(ASRUtils::symbol_get_past_external(der_type->m_derived_type)); + } uint32_t v_h = get_hash((ASR::asr_t*)v); LFORTRAN_ASSERT(llvm_symtab.find(v_h) != llvm_symtab.end()); llvm::Value* array = llvm_symtab[v_h]; diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 9ceb53bd56..489e56e4f6 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -2836,6 +2836,11 @@ class BodyVisitor : public CommonVisitor { visit_Attribute(*x_m_value); ASR::expr_t* e = ASRUtils::EXPR(tmp); visit_AttributeUtil(ASRUtils::expr_type(e), x.m_attr, e, x.base.base.loc); + } else if(AST::is_a(*x.m_value)) { + AST::Subscript_t* x_m_value = AST::down_cast(x.m_value); + visit_Subscript(*x_m_value); + ASR::expr_t* e = ASRUtils::EXPR(tmp); + visit_AttributeUtil(ASRUtils::expr_type(e), x.m_attr, e, x.base.base.loc); } else { throw SemanticError("Only Name, Attribute is supported for now in Attribute", x.base.base.loc); diff --git a/tests/reference/asr-structs_05-fa98307.json b/tests/reference/asr-structs_05-fa98307.json new file mode 100644 index 0000000000..15588e5d2d --- /dev/null +++ b/tests/reference/asr-structs_05-fa98307.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-structs_05-fa98307", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/../integration_tests/structs_05.py", + "infile_hash": "8022943f9e2faac2fac4bd147f9ccba284127429215ddef38d22beb8", + "outfile": null, + "outfile_hash": null, + "stdout": "asr-structs_05-fa98307.stdout", + "stdout_hash": "5611b3927900d1071ff9b00de1306cc5baeba9fefd0947d49552336e", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/asr-structs_05-fa98307.stdout b/tests/reference/asr-structs_05-fa98307.stdout new file mode 100644 index 0000000000..39a1f094be --- /dev/null +++ b/tests/reference/asr-structs_05-fa98307.stdout @@ -0,0 +1 @@ +(TranslationUnit (SymbolTable 1 {A: (DerivedType (SymbolTable 2 {x: (Variable 2 x Local () () Default (Integer 4 []) Source Public Required .false.), y: (Variable 2 y Local () () Default (Real 8 []) Source Public Required .false.)}) A [y x] Source Public ()), _lpython_main_program: (Subroutine (SymbolTable 84 {}) _lpython_main_program [] [(SubroutineCall 1 g () [] ())] Source Public Implementation () .false. .false.), abs@__lpython_overloaded_0__abs: (ExternalSymbol 1 abs@__lpython_overloaded_0__abs 8 __lpython_overloaded_0__abs lpython_builtin [] __lpython_overloaded_0__abs Public), g: (Subroutine (SymbolTable 6 {y: (Variable 6 y Local () () Default (Derived 1 A [((IntegerConstant 0 (Integer 4 [])) (IntegerBinOp (IntegerConstant 2 (Integer 4 [])) Sub (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) (IntegerConstant 1 (Integer 4 []))))]) Source Public Required .false.)}) g [] [(= (ArrayItem 6 y [(() (IntegerConstant 0 (Integer 4 [])) ())] (Derived 1 A [((IntegerConstant 0 (Integer 4 [])) (IntegerBinOp (IntegerConstant 2 (Integer 4 [])) Sub (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) (IntegerConstant 1 (Integer 4 []))))]) ()) (DerivedTypeConstructor 1 A [(RealConstant 1.10000000000000009e+00 (Real 8 [])) (IntegerConstant 1 (Integer 4 []))] (Derived 1 A []) ()) ()) (= (ArrayItem 6 y [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [((IntegerConstant 0 (Integer 4 [])) (IntegerBinOp (IntegerConstant 2 (Integer 4 [])) Sub (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) (IntegerConstant 1 (Integer 4 []))))]) ()) (DerivedTypeConstructor 1 A [(RealConstant 2.20000000000000018e+00 (Real 8 [])) (IntegerConstant 2 (Integer 4 []))] (Derived 1 A []) ()) ()) (SubroutineCall 1 verify () [((Var 6 y)) ((IntegerConstant 1 (Integer 4 []))) ((RealConstant 1.10000000000000009e+00 (Real 8 []))) ((IntegerConstant 2 (Integer 4 []))) ((RealConstant 2.20000000000000018e+00 (Real 8 [])))] ()) (SubroutineCall 1 update_1 () [((ArrayItem 6 y [(() (IntegerConstant 0 (Integer 4 [])) ())] (Derived 1 A [((IntegerConstant 0 (Integer 4 [])) (IntegerBinOp (IntegerConstant 2 (Integer 4 [])) Sub (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) (IntegerConstant 1 (Integer 4 []))))]) ()))] ()) (SubroutineCall 1 update_2 () [((Var 6 y))] ()) (SubroutineCall 1 verify () [((Var 6 y)) ((IntegerConstant 2 (Integer 4 []))) ((RealConstant 1.19999999999999996e+00 (Real 8 []))) ((IntegerConstant 3 (Integer 4 []))) ((RealConstant 2.29999999999999982e+00 (Real 8 [])))] ())] Source Public Implementation () .false. .false.), lpython_builtin: (IntrinsicModule lpython_builtin), main_program: (Program (SymbolTable 83 {}) main_program [] [(SubroutineCall 1 _lpython_main_program () [] ())]), update_1: (Subroutine (SymbolTable 4 {s: (Variable 4 s In () () Default (Derived 1 A []) Source Public Required .false.)}) update_1 [(Var 4 s)] [(= (DerivedRef (Var 4 s) 2 x (Integer 4 []) ()) (IntegerConstant 2 (Integer 4 [])) ()) (= (DerivedRef (Var 4 s) 2 y (Real 8 []) ()) (RealConstant 1.19999999999999996e+00 (Real 8 [])) ())] Source Public Implementation () .false. .false.), update_2: (Subroutine (SymbolTable 5 {s: (Variable 5 s InOut () () Default (Derived 1 A [(() ())]) Source Public Required .false.)}) update_2 [(Var 5 s)] [(= (DerivedRef (ArrayItem 5 s [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 x (Integer 4 []) ()) (IntegerConstant 3 (Integer 4 [])) ()) (= (DerivedRef (ArrayItem 5 s [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 y (Real 8 []) ()) (RealConstant 2.29999999999999982e+00 (Real 8 [])) ())] Source Public Implementation () .false. .false.), verify: (Subroutine (SymbolTable 3 {abs: (ExternalSymbol 3 abs 8 abs lpython_builtin [] abs Private), eps: (Variable 3 eps Local () () Default (Real 8 []) Source Public Required .false.), s: (Variable 3 s InOut () () Default (Derived 1 A [(() ())]) Source Public Required .false.), x1: (Variable 3 x1 In () () Default (Integer 4 []) Source Public Required .false.), x2: (Variable 3 x2 In () () Default (Integer 4 []) Source Public Required .false.), y1: (Variable 3 y1 In () () Default (Real 8 []) Source Public Required .false.), y2: (Variable 3 y2 In () () Default (Real 8 []) Source Public Required .false.)}) verify [(Var 3 s) (Var 3 x1) (Var 3 y1) (Var 3 x2) (Var 3 y2)] [(= (Var 3 eps) (RealConstant 9.99999999999999980e-13 (Real 8 [])) ()) (Print () [(DerivedRef (ArrayItem 3 s [(() (IntegerConstant 0 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 x (Integer 4 []) ()) (DerivedRef (ArrayItem 3 s [(() (IntegerConstant 0 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 y (Real 8 []) ())] () ()) (Assert (IntegerCompare (DerivedRef (ArrayItem 3 s [(() (IntegerConstant 0 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 x (Integer 4 []) ()) Eq (Var 3 x1) (Logical 4 []) ()) ()) (Assert (RealCompare (FunctionCall 1 abs@__lpython_overloaded_0__abs 3 abs [((RealBinOp (DerivedRef (ArrayItem 3 s [(() (IntegerConstant 0 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 y (Real 8 []) ()) Sub (Var 3 y1) (Real 8 []) ()))] (Real 8 []) () ()) Lt (Var 3 eps) (Logical 4 []) ()) ()) (Print () [(DerivedRef (ArrayItem 3 s [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 x (Integer 4 []) ()) (DerivedRef (ArrayItem 3 s [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 y (Real 8 []) ())] () ()) (Assert (IntegerCompare (DerivedRef (ArrayItem 3 s [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 x (Integer 4 []) ()) Eq (Var 3 x2) (Logical 4 []) ()) ()) (Assert (RealCompare (FunctionCall 1 abs@__lpython_overloaded_0__abs 3 abs [((RealBinOp (DerivedRef (ArrayItem 3 s [(() (IntegerConstant 1 (Integer 4 [])) ())] (Derived 1 A [(() ())]) ()) 2 y (Real 8 []) ()) Sub (Var 3 y2) (Real 8 []) ()))] (Real 8 []) () ()) Lt (Var 3 eps) (Logical 4 []) ()) ())] Source Public Implementation () .false. .false.)}) []) diff --git a/tests/tests.toml b/tests/tests.toml index 070cb29674..a465074827 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -251,6 +251,10 @@ asr = true filename = "../integration_tests/structs_04.py" asr = true +[[test]] +filename = "../integration_tests/structs_05.py" +asr = true + [[test]] filename = "../integration_tests/bindc_01.py" asr = true