Skip to content

Commit

Permalink
Merge pull request lcompilers#2056 from Smit-create/i-1981-1
Browse files Browse the repository at this point in the history
Store initializer exprs for Structs in ASR
  • Loading branch information
certik committed Jun 30, 2023
2 parents 32273ce + bdbcee7 commit e642072
Show file tree
Hide file tree
Showing 26 changed files with 334 additions and 182 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ RUN(NAME structs_28 LABELS cpython llvm c)
RUN(NAME structs_29 LABELS cpython llvm)
RUN(NAME structs_30 LABELS cpython llvm c)
RUN(NAME structs_31 LABELS cpython llvm c)
RUN(NAME structs_32 LABELS cpython llvm c)

RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
Expand Down
8 changes: 4 additions & 4 deletions integration_tests/structs_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

@dataclass
class Mat:
mat: f64[2, 2]
mat: f64[2, 2] = empty((2, 2), dtype=float64)

@dataclass
class Vec:
vec: f64[2]
vec: f64[2] = empty(2, dtype=float64)

@dataclass
class MatVec:
mat: Mat = Mat([f64(0.0), f64(0.0)])
vec: Vec = Vec([f64(0.0), f64(0.0)])
mat: Mat = Mat()
vec: Vec = Vec()

def rotate(mat_vec: MatVec) -> f64[2]:
rotated_vec: f64[2] = empty(2, dtype=float64)
Expand Down
29 changes: 28 additions & 1 deletion integration_tests/structs_27.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from lpython import dataclass, i32
from lpython import dataclass, i32, u16, f32


@dataclass
Expand All @@ -7,6 +7,15 @@ class StringIO:
_0cursor : i32 = 10
_len : i32 = 1

@dataclass
class StringIONew:
_buf : str
_0cursor : i32 = i32(142)
_len : i32 = i32(2439)
_var1 : u16 = u16(23)
_var2 : f32 = f32(30.24)

#print("ok")

def test_issue_1928():
integer_asr : str = '(Integer 4 [])'
Expand Down Expand Up @@ -47,4 +56,22 @@ def test_issue_1928():
assert test_dude4._0cursor == 31


def test_issue_1981():
integer_asr : str = '(Integer 4 [])'
test_dude : StringIONew = StringIONew(integer_asr)
assert test_dude._buf == integer_asr
assert test_dude._len == 2439
assert test_dude._0cursor == 142
assert test_dude._var1 == u16(23)
assert abs(test_dude._var2 - f32(30.24)) < f32(1e-5)
test_dude._len = 13
test_dude._0cursor = 52
test_dude._var1 = u16(34)
assert test_dude._buf == integer_asr
assert test_dude._len == 13
assert test_dude._0cursor == 52
assert test_dude._var1 == u16(34)


test_issue_1981()
test_issue_1928()
45 changes: 45 additions & 0 deletions integration_tests/structs_32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from lpython import packed, dataclass, i32, InOut


@packed
@dataclass
class inner_struct:
a: i32


@packed
@dataclass
class outer_struct:
b: inner_struct = inner_struct(0)


def update_my_inner_struct(my_inner_struct: InOut[inner_struct]) -> None:
my_inner_struct.a = 99999


def update_my_outer_struct(my_outer_struct: InOut[outer_struct]) -> None:
my_outer_struct.b.a = 12345


def main() -> None:
my_outer_struct: outer_struct = outer_struct()
my_inner_struct: inner_struct = my_outer_struct.b

assert my_outer_struct.b.a == 0

my_outer_struct.b.a = 12345
assert my_outer_struct.b.a == 12345

my_outer_struct.b.a = 0
assert my_outer_struct.b.a == 0

update_my_outer_struct(my_outer_struct)
assert my_outer_struct.b.a == 12345

my_inner_struct.a = 1111
assert my_inner_struct.a == 1111

update_my_inner_struct(my_inner_struct)
assert my_inner_struct.a == 99999

main()
4 changes: 2 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ symbol
identifier original_name, access access)
| StructType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, bool is_packed, bool is_abstract,
expr? alignment, symbol? parent)
call_arg* initializers, expr? alignment, symbol? parent)
| EnumType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, enumtype enum_value_type,
ttype type, symbol? parent)
| UnionType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, symbol? parent)
identifier* members, abi abi, access access, call_arg* initializers, symbol? parent)
| Variable(symbol_table parent_symtab, identifier name, identifier* dependencies,
intent intent, expr? symbolic_value, expr? value, storage_type storage,
ttype type, symbol? type_declaration,
Expand Down
7 changes: 4 additions & 3 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,12 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
void allocate_array_members_of_struct(ASR::StructType_t* der_type_t, std::string& sub,
std::string indent, std::string name) {
for( auto itr: der_type_t->m_symtab->get_scope() ) {
if( ASR::is_a<ASR::UnionType_t>(*itr.second) ||
ASR::is_a<ASR::StructType_t>(*itr.second) ) {
ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(itr.second);
if( ASR::is_a<ASR::UnionType_t>(*sym) ||
ASR::is_a<ASR::StructType_t>(*sym) ) {
continue ;
}
ASR::ttype_t* mem_type = ASRUtils::symbol_type(itr.second);
ASR::ttype_t* mem_type = ASRUtils::symbol_type(sym);
if( ASRUtils::is_character(*mem_type) ) {
sub += indent + name + "->" + itr.first + " = NULL;\n";
} else if( ASRUtils::is_array(mem_type) &&
Expand Down
5 changes: 5 additions & 0 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,11 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
last_expr_precedence = 2;
}

void visit_UnsignedIntegerConstant(const ASR::UnsignedIntegerConstant_t &x) {
src = std::to_string(x.m_n);
last_expr_precedence = 2;
}

void visit_RealConstant(const ASR::RealConstant_t &x) {
// TODO: remove extra spaces from the front of double_to_scientific result
src = double_to_scientific(x.m_r);
Expand Down
20 changes: 12 additions & 8 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Type* max_sized_type = nullptr;
size_t max_type_size = 0;
for( auto itr = scope.begin(); itr != scope.end(); itr++ ) {
ASR::Variable_t* member = ASR::down_cast<ASR::Variable_t>(itr->second);
ASR::Variable_t* member = ASR::down_cast<ASR::Variable_t>(ASRUtils::symbol_get_past_external(itr->second));
llvm::Type* llvm_mem_type = getMemberType(member->m_type, member);
size_t type_size = data_layout.getTypeAllocSize(llvm_mem_type);
if( max_type_size < type_size ) {
Expand Down Expand Up @@ -3316,14 +3316,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ASRUtils::symbol_get_past_external(struct_t->m_derived_type));
std::string struct_type_name = struct_type_t->m_name;
for( auto item: struct_type_t->m_symtab->get_scope() ) {
if( ASR::is_a<ASR::ClassProcedure_t>(*item.second) ||
ASR::is_a<ASR::GenericProcedure_t>(*item.second) ||
ASR::is_a<ASR::UnionType_t>(*item.second) ||
ASR::is_a<ASR::StructType_t>(*item.second) ||
ASR::is_a<ASR::CustomOperator_t>(*item.second) ) {
ASR::symbol_t *sym = ASRUtils::symbol_get_past_external(item.second);
if (name2memidx[struct_type_name].find(item.first) == name2memidx[struct_type_name].end()) {
continue;
}
if( ASR::is_a<ASR::ClassProcedure_t>(*sym) ||
ASR::is_a<ASR::GenericProcedure_t>(*sym) ||
ASR::is_a<ASR::UnionType_t>(*sym) ||
ASR::is_a<ASR::StructType_t>(*sym) ||
ASR::is_a<ASR::CustomOperator_t>(*sym) ) {
continue ;
}
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second);
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(sym);
int idx = name2memidx[struct_type_name][item.first];
llvm::Value* ptr_member = llvm_utils->create_gep(ptr, idx);
ASR::Variable_t* v = nullptr;
Expand Down Expand Up @@ -8569,7 +8573,7 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
pass_manager.apply_passes(al, &asr, pass_options, diagnostics);

// Uncomment for debugging the ASR after the transformation
// std::cout << LFortran::pickle(asr, false, false, false) << std::endl;
// std::cout << LCompilers::LPython::pickle(asr, true, true, false) << std::endl;

try {
v.visit_asr((ASR::asr_t&)asr);
Expand Down
18 changes: 10 additions & 8 deletions src/libasr/codegen/c_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -843,26 +843,28 @@ class CCPPDSUtils {
+ struct_type_str + "* dest)";
func_decls += "inline " + signature + ";\n";
generated_code += indent + signature + " {\n";
for( auto item: struct_type_t->m_symtab->get_scope() ) {
ASR::ttype_t* member_type_asr = ASRUtils::symbol_type(item.second);
for(size_t i=0; i < struct_type_t->n_members; i++) {
std::string mem_name = std::string(struct_type_t->m_members[i]);
ASR::symbol_t* member = struct_type_t->m_symtab->get_symbol(mem_name);
ASR::ttype_t* member_type_asr = ASRUtils::symbol_type(member);
if( CUtils::is_non_primitive_DT(member_type_asr) ||
ASR::is_a<ASR::Character_t>(*member_type_asr) ) {
generated_code += indent + tab + get_deepcopy(member_type_asr, "&(src->" + item.first + ")",
"&(dest->" + item.first + ")") + ";\n";
generated_code += indent + tab + get_deepcopy(member_type_asr, "&(src->" + mem_name + ")",
"&(dest->" + mem_name + ")") + ";\n";
} else if( ASRUtils::is_array(member_type_asr) ) {
ASR::dimension_t* m_dims = nullptr;
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(member_type_asr, m_dims);
if( ASRUtils::is_fixed_size_array(m_dims, n_dims) ) {
std::string array_size = std::to_string(ASRUtils::get_fixed_size_of_array(m_dims, n_dims));
array_size += "*sizeof(" + CUtils::get_c_type_from_ttype_t(member_type_asr) + ")";
generated_code += indent + tab + "memcpy(dest->" + item.first + ", src->" + item.first +
generated_code += indent + tab + "memcpy(dest->" + mem_name + ", src->" + mem_name +
", " + array_size + ");\n";
} else {
generated_code += indent + tab + get_deepcopy(member_type_asr, "src->" + item.first,
"dest->" + item.first) + ";\n";
generated_code += indent + tab + get_deepcopy(member_type_asr, "src->" + mem_name,
"dest->" + mem_name) + ";\n";
}
} else {
generated_code += indent + tab + "dest->" + item.first + " = " + " src->" + item.first + ";\n";
generated_code += indent + tab + "dest->" + mem_name + " = " + " src->" + mem_name + ";\n";
}
}
generated_code += indent + "}\n\n";
Expand Down
Loading

0 comments on commit e642072

Please sign in to comment.