Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set data structure #2122

Merged
merged 9 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,9 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
RUN(NAME test_dict_13 LABELS cpython llvm c)
RUN(NAME test_dict_bool LABELS cpython llvm)
RUN(NAME test_dict_increment LABELS cpython llvm)
RUN(NAME test_set_len LABELS cpython llvm)
RUN(NAME test_set_add LABELS cpython llvm)
RUN(NAME test_set_remove LABELS cpython llvm)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
34 changes: 34 additions & 0 deletions integration_tests/test_set_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from lpython import i32

def test_set_add():
s1: set[i32]
s2: set[tuple[i32, tuple[i32, i32], str]]
s3: set[str]
st1: str
i: i32
j: i32

s1 = {0}
s2 = {(0, (1, 2), 'a')}
for i in range(20):
j = i % 10
s1.add(j)
s2.add((j, (j + 1, j + 2), 'a'))
assert len(s1) == len(s2)
if i < 10:
assert len(s1) == i + 1
else:
assert len(s1) == 10

st1 = 'a'
s3 = {st1}
for i in range(20):
s3.add(st1)
if i < 10:
if i > 0:
assert len(s3) == i
st1 += 'a'
else:
assert len(s3) == 10

test_set_add()
8 changes: 8 additions & 0 deletions integration_tests/test_set_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from lpython import i32

def test_set():
s: set[i32]
s = {1, 2, 22, 2, -1, 1}
assert len(s) == 4

test_set()
47 changes: 47 additions & 0 deletions integration_tests/test_set_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from lpython import i32

def test_set_add():
s1: set[i32]
s2: set[tuple[i32, tuple[i32, i32], str]]
s3: set[str]
st1: str
i: i32
j: i32
k: i32

for k in range(2):
s1 = {0}
s2 = {(0, (1, 2), 'a')}
for i in range(20):
j = i % 10
s1.add(j)
s2.add((j, (j + 1, j + 2), 'a'))

for i in range(10):
s1.remove(i)
s2.remove((i, (i + 1, i + 2), 'a'))
# assert len(s1) == 10 - 1 - i
# assert len(s1) == len(s2)

st1 = 'a'
s3 = {st1}
for i in range(20):
s3.add(st1)
if i < 10:
if i > 0:
st1 += 'a'

st1 = 'a'
for i in range(10):
s3.remove(st1)
assert len(s3) == 10 - 1 - i
if i < 10:
st1 += 'a'

for i in range(20):
s1.add(i)
if i % 2 == 0:
s1.remove(i)
assert len(s1) == (i + 1) // 2

test_set_add()
2 changes: 0 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ stmt
| SelectType(expr selector, type_stmt* body, stmt* default)
| CPtrToPointer(expr cptr, expr ptr, expr? shape, expr? lower_bounds)
| BlockCall(int label, symbol m)
| SetInsert(expr a, expr ele)
| SetRemove(expr a, expr ele)
| ListInsert(expr a, expr pos, expr ele)
| ListRemove(expr a, expr ele)
| ListClear(expr a)
Expand Down
86 changes: 86 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::unique_ptr<LLVMTuple> tuple_api;
std::unique_ptr<LLVMDictInterface> dict_api_lp;
std::unique_ptr<LLVMDictInterface> dict_api_sc;
std::unique_ptr<LLVMSetInterface> set_api; // linear probing
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;

ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile,
Expand All @@ -199,13 +200,15 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tuple_api(std::make_unique<LLVMTuple>(context, llvm_utils.get(), builder.get())),
dict_api_lp(std::make_unique<LLVMDictOptimizedLinearProbing>(context, llvm_utils.get(), builder.get())),
dict_api_sc(std::make_unique<LLVMDictSeparateChaining>(context, llvm_utils.get(), builder.get())),
set_api(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context,
builder.get(), llvm_utils.get(),
LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor))
{
llvm_utils->tuple_api = tuple_api.get();
llvm_utils->list_api = list_api.get();
llvm_utils->dict_api = nullptr;
llvm_utils->set_api = set_api.get();
llvm_utils->arr_api = arr_descr.get();
llvm_utils->dict_api_lp = dict_api_lp.get();
llvm_utils->dict_api_sc = dict_api_sc.get();
Expand Down Expand Up @@ -1149,6 +1152,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = const_dict;
}

void visit_SetConstant(const ASR::SetConstant_t& x) {
llvm::Type* const_set_type = llvm_utils->get_set_type(x.m_type, module.get());
llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set");
ASR::Set_t* x_set = ASR::down_cast<ASR::Set_t>(x.m_type);
std::string el_type_code = ASRUtils::get_type_code(x_set->m_type);
llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements);
int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type);
int64_t ptr_loads_copy = ptr_loads;
for( size_t i = 0; i < x.n_elements; i++ ) {
ptr_loads = ptr_loads_el;
visit_expr_wrapper(x.m_elements[i], true);
llvm::Value* element = tmp;
llvm_utils->set_api->write_item(const_set, element, module.get(),
x_set->m_type, name2memidx);
}
ptr_loads = ptr_loads_copy;
tmp = const_set;
}

void visit_TupleConstant(const ASR::TupleConstant_t& x) {
ASR::Tuple_t* tuple_type = ASR::down_cast<ASR::Tuple_t>(x.m_type);
std::string type_code = ASRUtils::get_type_code(tuple_type->m_type,
Expand Down Expand Up @@ -1487,6 +1509,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = llvm_utils->dict_api->len(pdict);
}

void visit_SetLen(const ASR::SetLen_t& x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
return ;
}

int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_arg);
ptr_loads = ptr_loads_copy;
llvm::Value* pset = tmp;
tmp = llvm_utils->set_api->len(pset);
}

void visit_ListInsert(const ASR::ListInsert_t& x) {
ASR::List_t* asr_list = ASR::down_cast<ASR::List_t>(
ASRUtils::expr_type(x.m_a));
Expand Down Expand Up @@ -1648,6 +1684,34 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
}

void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*m_arg);
llvm::Value* pset = tmp;

ptr_loads = 2;
this->visit_expr_wrapper(m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
}

void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*m_arg);
llvm::Value* pset = tmp;

ptr_loads = 2;
this->visit_expr_wrapper(m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
set_api->remove_item(pset, el, *module, asr_el_type);
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
switch (static_cast<ASRUtils::IntrinsicFunctions>(x.m_intrinsic_id)) {
case ASRUtils::IntrinsicFunctions::ListIndex: {
Expand Down Expand Up @@ -1691,6 +1755,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break;
}
case ASRUtils::IntrinsicFunctions::SetAdd: {
generate_SetAdd(x.m_args[0], x.m_args[1]);
break;
}
case ASRUtils::IntrinsicFunctions::SetRemove: {
generate_SetRemove(x.m_args[0], x.m_args[1]);
break;
}
case ASRUtils::IntrinsicFunctions::Exp: {
switch (x.m_overload_id) {
case 0: {
Expand Down Expand Up @@ -3945,6 +4017,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
bool is_value_tuple = ASR::is_a<ASR::Tuple_t>(*asr_value_type);
bool is_target_dict = ASR::is_a<ASR::Dict_t>(*asr_target_type);
bool is_value_dict = ASR::is_a<ASR::Dict_t>(*asr_value_type);
bool is_target_set = ASR::is_a<ASR::Set_t>(*asr_target_type);
bool is_value_set = ASR::is_a<ASR::Set_t>(*asr_value_type);
bool is_target_struct = ASR::is_a<ASR::Struct_t>(*asr_target_type);
bool is_value_struct = ASR::is_a<ASR::Struct_t>(*asr_value_type);
if (ASR::is_a<ASR::StringSection_t>(*x.m_target)) {
Expand Down Expand Up @@ -4034,6 +4108,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict,
value_dict_type, module.get(), name2memidx);
return ;
} else if( is_target_set && is_value_set ) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_value);
llvm::Value* value_set = tmp;
this->visit_expr(*x.m_target);
llvm::Value* target_set = tmp;
ptr_loads = ptr_loads_copy;
ASR::Set_t* value_set_type = ASR::down_cast<ASR::Set_t>(asr_value_type);
llvm_utils->set_api->set_deepcopy(value_set, target_set,
value_set_type, module.get(), name2memidx);
return ;
} else if( is_target_struct && is_value_struct ) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
Expand Down
Loading