Skip to content

Commit

Permalink
Add dict.keys and dict.values (lcompilers#2023)
Browse files Browse the repository at this point in the history
  • Loading branch information
kabra1110 authored Jul 25, 2023
1 parent a183feb commit 634177d
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 2 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
RUN(NAME test_dict_13 LABELS cpython llvm c)
RUN(NAME test_dict_bool LABELS cpython llvm)
RUN(NAME test_dict_increment LABELS cpython llvm)
RUN(NAME test_dict_keys_values LABELS cpython llvm)
RUN(NAME test_set_len LABELS cpython llvm)
RUN(NAME test_set_add LABELS cpython llvm)
RUN(NAME test_set_remove LABELS cpython llvm)
Expand Down
54 changes: 54 additions & 0 deletions integration_tests/test_dict_keys_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from lpython import i32, f64

# Integration test for dict.keys() and dict.values().
# For an i32 -> i32 dict and a str -> str dict with ten entries each, it
# asserts that keys() has length 10, that every inserted key occurs exactly
# once in the keys list, and that the values list holds the matching value
# at the same index as its key.
def test_dict_keys_values():
d1: dict[i32, i32] = {}
k1: list[i32]
k1_copy: list[i32] = []
v1: list[i32]
v1_copy: list[i32] = []
i: i32
j: i32
s: str
key_count: i32

# Populate d1 with keys 105..114, each mapped to key + 1.
for i in range(105, 115):
d1[i] = i + 1
# Copy the results of keys()/values() element by element, which also
# exercises iteration over the returned lists.
k1 = d1.keys()
for i in k1:
k1_copy.append(i)
v1 = d1.values()
for i in v1:
v1_copy.append(i)
assert len(k1) == 10
# For each inserted key: count its occurrences in the keys list and check
# the value stored at the same position of the values list.
for i in range(105, 115):
key_count = 0
for j in range(len(k1)):
if k1_copy[j] == i:
key_count += 1
assert v1_copy[j] == d1[i]
assert key_count == 1

# Same checks for a dict with string keys and string values.
d2: dict[str, str] = {}
k2: list[str]
k2_copy: list[str] = []
v2: list[str]
v2_copy: list[str] = []

# Populate d2 with str(105)..str(114), each mapped to str(key + 1).
for i in range(105, 115):
d2[str(i)] = str(i + 1)
k2 = d2.keys()
for s in k2:
k2_copy.append(s)
v2 = d2.values()
for s in v2:
v2_copy.append(s)
assert len(k2) == 10
# Each key must appear exactly once, paired by index with its value.
for i in range(105, 115):
key_count = 0
for j in range(len(k2)):
if k2_copy[j] == str(i):
key_count += 1
assert v2_copy[j] == d2[str(i)]
assert key_count == 1

test_dict_keys_values()
12 changes: 11 additions & 1 deletion src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,24 @@ static inline ASR::abiType symbol_abi(const ASR::symbol_t *f)
return ASR::abiType::Source;
}

static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) {
static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type, int overload=0) {
switch( asr_type->type ) {
case ASR::ttypeType::List: {
return ASR::down_cast<ASR::List_t>(asr_type)->m_type;
}
case ASR::ttypeType::Set: {
return ASR::down_cast<ASR::Set_t>(asr_type)->m_type;
}
case ASR::ttypeType::Dict: {
switch( overload ) {
case 0:
return ASR::down_cast<ASR::Dict_t>(asr_type)->m_key_type;
case 1:
return ASR::down_cast<ASR::Dict_t>(asr_type)->m_value_type;
default:
return asr_type;
}
}
case ASR::ttypeType::Enum: {
ASR::Enum_t* enum_asr = ASR::down_cast<ASR::Enum_t>(asr_type);
ASR::EnumType_t* enum_type = ASR::down_cast<ASR::EnumType_t>(enum_asr->m_enum_type);
Expand Down
51 changes: 51 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,49 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
}

// Lowers the DictKeys / DictValues intrinsics: builds a fresh list on the
// stack and asks the dict API to fill it with the keys (key_or_value == false)
// or the values (key_or_value == true) of the dict expression `m_arg`.
// The resulting list pointer is left in `tmp`.
void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value) {
    ASR::ttype_t* arg_type = ASRUtils::expr_type(m_arg);
    ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(arg_type);

    // Element type of the produced list: key type by default, value type
    // when values were requested.
    ASR::ttype_t* el_type = dict_type->m_key_type;
    if( key_or_value ) {
        el_type = dict_type->m_value_type;
    }

    // Visit the dict expression with ptr_loads forced to 0, then restore
    // the caller's setting.
    const int64_t saved_ptr_loads = ptr_loads;
    ptr_loads = 0;
    this->visit_expr(*m_arg);
    ptr_loads = saved_ptr_loads;
    llvm::Value* pdict = tmp;

    // Arguments required by get_type_from_ttype_t (presumably filled in by
    // the call); none of them is read afterwards in this function.
    bool is_array = false;
    bool is_malloc_array = false;
    bool is_list = false;
    ASR::dimension_t* dims = nullptr;
    int n_dims = -1;
    int a_kind = -1;
    llvm::Type* llvm_el_type = llvm_utils->get_type_from_ttype_t(
        el_type, nullptr, ASR::storage_typeType::Default,
        is_array, is_malloc_array, is_list, dims, n_dims, a_kind,
        module.get());
    std::string type_code = ASRUtils::get_type_code(el_type);

    // Strings, struct-like types and complex numbers take their size from
    // the module's data layout; every other type uses its ASR kind.
    const bool size_from_layout = ASR::is_a<ASR::Character_t>(*el_type)
        || LLVM::is_llvm_struct(el_type)
        || ASR::is_a<ASR::Complex_t>(*el_type);
    int32_t type_size = -1;
    if( size_from_layout ) {
        llvm::DataLayout data_layout(module.get());
        type_size = data_layout.getTypeAllocSize(llvm_el_type);
    } else {
        type_size = ASRUtils::extract_kind_from_ttype_t(el_type);
    }

    // Allocate and initialise the destination list.
    llvm::Type* el_list_type = list_api->get_list_type(llvm_el_type, type_code, type_size);
    const char* list_name = "keys_list";
    if( key_or_value ) {
        list_name = "values_list";
    }
    llvm::Value* el_list = builder->CreateAlloca(el_list_type, nullptr, list_name);
    list_api->list_init(type_code, el_list, *module, 0, 0);

    // Select the dict implementation for this dict type and let it append
    // the requested elements to the new list.
    llvm_utils->set_dict_api(dict_type);
    llvm_utils->dict_api->get_elements_list(pdict, el_list,
        dict_type->m_key_type, dict_type->m_value_type,
        *module, name2memidx, key_or_value);
    tmp = el_list;
}

void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
int64_t ptr_loads_copy = ptr_loads;
Expand Down Expand Up @@ -1755,6 +1798,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
break;
}
case ASRUtils::IntrinsicFunctions::DictKeys: {
generate_DictElems(x.m_args[0], 0);
break;
}
case ASRUtils::IntrinsicFunctions::DictValues: {
generate_DictElems(x.m_args[0], 1);
break;
}
case ASRUtils::IntrinsicFunctions::SetAdd: {
generate_SetAdd(x.m_args[0], x.m_args[1]);
break;
Expand Down
178 changes: 178 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,12 @@ namespace LCompilers {
list_api->list_deepcopy(src, dest, list_type, module, name2memidx);
break ;
}
case ASR::ttypeType::Dict: {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
// set dict api here?
dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx);
break ;
}
case ASR::ttypeType::Struct: {
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(asr_type);
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
Expand Down Expand Up @@ -3865,6 +3871,178 @@ namespace LCompilers {
return LLVM::CreateLoad(*builder, value_ptr);
}

/**
 * Emits IR that appends every stored key (key_or_value == false) or every
 * stored value (key_or_value == true) of `dict` to `elements_list`.
 *
 * Shape of the generated code:
 *
 *     idx = 0;
 *     while( capacity > idx ) {
 *         mask = key_mask[idx];
 *         // 0 marks an unset slot, 3 marks a tombstone
 *         if( mask != 0 && !(mask == 3) ) {
 *             elements_list.append(key_or_value_list[idx]);
 *         }
 *         idx++;
 *     }
 */
void LLVMDict::get_elements_list(llvm::Value* dict,
    llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
    ASR::ttype_t* value_asr_type, llvm::Module& module,
    std::map<std::string, std::map<std::string, int>>& name2memidx,
    bool key_or_value) {

    // Types and constants used below; creating them emits no instructions.
    llvm::Type* i32_type = llvm::Type::getInt32Ty(context);
    llvm::Type* i8_type = llvm::Type::getInt8Ty(context);
    llvm::Value* zero_i32 = llvm::ConstantInt::get(i32_type, llvm::APInt(32, 0));
    llvm::Value* one_i32 = llvm::ConstantInt::get(i32_type, llvm::APInt(32, 1));
    llvm::Value* tombstone_mark = llvm::ConstantInt::get(i8_type, llvm::APInt(8, 3));
    llvm::Value* unset_mark = llvm::ConstantInt::get(i8_type, llvm::APInt(8, 0));

    llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
    llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));

    // Pick the backing list and element type that match the request.
    llvm::Value* source_list = nullptr;
    ASR::ttype_t* el_asr_type = nullptr;
    if( key_or_value ) {
        source_list = get_value_list(dict);
        el_asr_type = value_asr_type;
    } else {
        source_list = get_key_list(dict);
        el_asr_type = key_asr_type;
    }

    // The loop counter slot is only allocated here when iterators have not
    // been set up already.
    if( !are_iterators_set ) {
        idx_ptr = builder->CreateAlloca(i32_type, nullptr);
    }
    LLVM::CreateStore(*builder, zero_i32, idx_ptr);

    llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
    llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
    llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

    // head: continue while capacity > idx
    llvm_utils->start_new_block(loophead);
    {
        llvm::Value* current = LLVM::CreateLoad(*builder, idx_ptr);
        llvm::Value* keep_going = builder->CreateICmpSGT(capacity, current);
        builder->CreateCondBr(keep_going, loopbody, loopend);
    }

    // body: append the element of every slot that is set and not a tombstone
    llvm_utils->start_new_block(loopbody);
    {
        llvm::Value* cur_idx = LLVM::CreateLoad(*builder, idx_ptr);
        llvm::Value* mask_slot = llvm_utils->create_ptr_gep(key_mask, cur_idx);
        llvm::Value* mask_value = LLVM::CreateLoad(*builder, mask_slot);
        llvm::Value* is_tombstone = builder->CreateICmpEQ(mask_value, tombstone_mark);
        llvm::Value* is_set = builder->CreateICmpNE(mask_value, unset_mark);
        llvm::Value* not_tombstone = builder->CreateNot(is_tombstone);
        llvm::Value* should_append = builder->CreateAnd(is_set, not_tombstone);

        llvm_utils->create_if_else(should_append, [&]() {
            // Struct-like elements are read as pointers, others as values.
            const bool read_as_pointer = LLVM::is_llvm_struct(el_asr_type);
            llvm::Value* element = llvm_utils->list_api->read_item(
                source_list, cur_idx, false, module, read_as_pointer);
            llvm_utils->list_api->append(elements_list, element,
                                         el_asr_type, &module, name2memidx);
        }, []() {
            // nothing to do for unset or tombstone slots
        });

        llvm::Value* next_idx = builder->CreateAdd(cur_idx, one_i32);
        LLVM::CreateStore(*builder, next_idx, idx_ptr);
    }

    builder->CreateBr(loophead);

    // end
    llvm_utils->start_new_block(loopend);
}

/**
 * Emits IR that appends every key (key_or_value == false) or every value
 * (key_or_value == true) stored in `dict` to `elements_list`.
 *
 * Shape of the generated code:
 *
 *     idx = 0;
 *     while( capacity > idx ) {
 *         if( key_mask[idx] == 1 ) {
 *             chain_itr = &key_value_pairs[idx];
 *             while( chain_itr != nullptr ) {
 *                 // field 0 is the key, field 1 the value
 *                 elements_list.append(chain_itr-><field key_or_value>);
 *                 chain_itr = chain_itr-><field 2>;   // next entry in chain
 *             }
 *         }
 *         idx++;
 *     }
 */
void LLVMDictSeparateChaining::get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
ASR::ttype_t* value_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value) {
// The iterator slots are only allocated here when they have not been set
// up already (presumably by a caller that owns them; confirm).
if( !are_iterators_set ) {
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
}
// idx = 0
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 0)), idx_ptr);

llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict));
llvm::Type* kv_pair_type = get_key_value_pair_type(key_asr_type, value_asr_type);
// ASR type of the elements being collected (key type or value type).
ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type;
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head: outer loop runs while capacity > idx
llvm_utils->start_new_block(loophead);
{
llvm::Value *cond = builder->CreateICmpSGT(
capacity,
LLVM::CreateLoad(*builder, idx_ptr));
builder->CreateCondBr(cond, loopbody, loopend);
}

// body: visit slot idx
llvm_utils->start_new_block(loopbody);
{
llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
llvm_utils->create_ptr_gep(key_mask, idx));
// Only slots whose key-mask byte equals 1 are walked.
llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));

llvm_utils->create_if_else(is_key_set, [&]() {
// Start the chain walk at the pair stored directly in slot idx;
// chain_itr holds it as an i8* so it can be compared against null.
llvm::Value* dict_i = llvm_utils->create_ptr_gep(key_value_pairs, idx);
llvm::Value* kv_ll_i8 = builder->CreateBitCast(dict_i, llvm::Type::getInt8PtrTy(context));
LLVM::CreateStore(*builder, kv_ll_i8, chain_itr);

llvm::BasicBlock *loop2head = llvm::BasicBlock::Create(context, "loop2.head");
llvm::BasicBlock *loop2body = llvm::BasicBlock::Create(context, "loop2.body");
llvm::BasicBlock *loop2end = llvm::BasicBlock::Create(context, "loop2.end");

// head: inner loop runs until chain_itr becomes null
llvm_utils->start_new_block(loop2head);
{
llvm::Value *cond = builder->CreateICmpNE(
LLVM::CreateLoad(*builder, chain_itr),
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
);
builder->CreateCondBr(cond, loop2body, loop2end);
}

// body: append the current pair's key or value, then advance
llvm_utils->start_new_block(loop2body);
{
llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_pair_type->getPointerTo());
// key_or_value doubles as the field index: 0 -> key, 1 -> value.
llvm::Value* kv_el = llvm_utils->create_gep(kv_struct, key_or_value);
// Struct-like elements are passed to append as a pointer; all
// other element types are loaded first.
if( !LLVM::is_llvm_struct(el_asr_type) ) {
kv_el = LLVM::CreateLoad(*builder, kv_el);
}
llvm_utils->list_api->append(elements_list, kv_el,
el_asr_type, &module, name2memidx);
// Field 2 of the pair holds the pointer to the next chain entry.
llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2));
LLVM::CreateStore(*builder, next_kv_struct, chain_itr);
}

builder->CreateBr(loop2head);

// end
llvm_utils->start_new_block(loop2end);
}, [=]() {
});
// idx++
llvm::Value* tmp = builder->CreateAdd(idx,
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, tmp, idx_ptr);
}

builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);
}

llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
bool enable_bounds_checking,
llvm::Module& module, bool get_pointer) {
Expand Down
19 changes: 19 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,13 @@ namespace LCompilers {
virtual
void set_is_dict_present(bool value);

/**
 * Emits code that appends the elements stored in `dict` to
 * `elements_list`: the keys when `key_or_value` is false (0), the
 * values when it is true (1). `key_asr_type` and `value_asr_type`
 * are the ASR types of the dict's keys and values.
 */
virtual
void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
ASR::ttype_t* value_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value) = 0;

virtual ~LLVMDictInterface() = 0;

};
Expand Down Expand Up @@ -713,6 +720,12 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

// Appends the dict's keys (key_or_value == false) or values
// (key_or_value == true) to elements_list by scanning every slot up to
// the capacity and skipping slots whose key mask is 0 (unset) or 3 (tombstone).
void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
ASR::ttype_t* value_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value);

virtual ~LLVMDict();
};

Expand Down Expand Up @@ -860,6 +873,12 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

// Appends the dict's keys (key_or_value == false) or values
// (key_or_value == true) to elements_list by walking, for every slot whose
// key mask is 1, the chain of key-value pairs until a null next pointer.
void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
ASR::ttype_t* value_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value);

virtual ~LLVMDictSeparateChaining();

};
Expand Down
Loading

0 comments on commit 634177d

Please sign in to comment.