Skip to content

Commit

Permalink
fix dict keys and values for LP
Browse files Browse the repository at this point in the history
  • Loading branch information
kabra1110 committed Jul 8, 2023
1 parent 9ad9c8b commit f28eaf0
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 6 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ RUN(NAME test_dict_12 LABELS cpython llvm c)
RUN(NAME test_dict_13 LABELS cpython llvm c)
RUN(NAME test_dict_bool LABELS cpython llvm)
RUN(NAME test_dict_increment LABELS cpython llvm)
RUN(NAME test_dict_keys_values LABELS cpython llvm)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
43 changes: 43 additions & 0 deletions integration_tests/test_dict_keys_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from lpython import i32, f64

def test_dict_keys_values():
d1: dict[i32, i32] = {}
d2: dict[tuple[i32, i32], tuple[i32, tuple[str, f64]]] = {}
k1: list[i32]
k2: list[tuple[i32, i32]]
v1: list[i32]
v2: list[tuple[i32, tuple[str, f64]]]
i: i32
j: i32
key_count: i32
s: str

for i in range(105, 115):
d1[i] = i + 1
k1 = d1.keys()
v1 = d1.values()
assert len(k1) == 10
for i in range(105, 115):
key_count = 0
for j in range(len(k1)):
if k1[j] == i:
key_count += 1
assert v1[j] == d1[i]
assert key_count == 1

s = 'a'
for i in range(10):
d2[(i, i + 1)] = (i, (s, f64(i * i)))
s += 'a'
k2 = d2.keys()
v2 = d2.values()
assert len(k2) == 10
for i in range(10):
key_count = 0
for j in range(len(k2)):
if k2[j] == (i, i + 1):
key_count += 1
assert v2[j] == d2[k2[j]]
assert key_count == 1

test_dict_keys_values()
43 changes: 40 additions & 3 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2072,18 +2072,51 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx);
}

void generate_DictKeys(ASR::expr_t* m_arg) {
void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value, const Location &loc) {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(
ASRUtils::expr_type(m_arg));
ASR::ttype_t* el_type = key_or_value == 0 ?
dict_type->m_key_type : dict_type->m_value_type;

int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*m_arg);
llvm::Value* pdict = tmp;

set_dict_api(dict_type);
if(llvm_utils->dict_api == dict_api_sc.get()) {
throw CodeGenError("dict.keys and dict.values are only implemented "
"for linear probing for now", loc);
}
ptr_loads = ptr_loads_copy;
tmp = llvm_utils->dict_api->get_key_list(pdict);

bool is_array_type_local = false, is_malloc_array_type_local = false;
bool is_list_local = false;
ASR::dimension_t* m_dims_local = nullptr;
int n_dims_local = -1, a_kind_local = -1;
llvm::Type* llvm_el_type = get_type_from_ttype_t(el_type,
nullptr,
ASR::storage_typeType::Default, is_array_type_local,
is_malloc_array_type_local, is_list_local, m_dims_local,
n_dims_local, a_kind_local);
std::string type_code = ASRUtils::get_type_code(el_type);
int32_t type_size = -1;
if( ASR::is_a<ASR::Character_t>(*el_type) ||
LLVM::is_llvm_struct(el_type) ||
ASR::is_a<ASR::Complex_t>(*el_type) ) {
llvm::DataLayout data_layout(module.get());
type_size = data_layout.getTypeAllocSize(llvm_el_type);
} else {
type_size = ASRUtils::extract_kind_from_ttype_t(el_type);
}
llvm::Type* el_list_type = list_api->get_list_type(llvm_el_type, type_code, type_size);
llvm::Value* el_list = builder->CreateAlloca(el_list_type, nullptr, key_or_value == 0 ?
"keys_list" : "values_list");
list_api->list_init(type_code, el_list, *module, 0, 0);

llvm_utils->dict_api->get_elements_list(pdict, el_list, el_type, *module,
name2memidx, key_or_value);
tmp = el_list;
}

void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
Expand Down Expand Up @@ -2130,7 +2163,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
break;
}
case ASRUtils::IntrinsicFunctions::DictKeys: {
generate_DictKeys(x.m_args[0]);
generate_DictElems(x.m_args[0], 0, x.base.base.loc);
break;
}
case ASRUtils::IntrinsicFunctions::DictValues: {
generate_DictElems(x.m_args[0], 1, x.base.base.loc);
break;
}
case ASRUtils::IntrinsicFunctions::Exp: {
Expand Down
89 changes: 89 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ namespace LCompilers {
list_api->list_deepcopy(src, dest, list_type, module, name2memidx);
break ;
}
case ASR::ttypeType::Dict: {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
// set dict api here?
dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx);
break ;
}
case ASR::ttypeType::Struct: {
ASR::Struct_t* struct_t = ASR::down_cast<ASR::Struct_t>(asr_type);
ASR::StructType_t* struct_type_t = ASR::down_cast<ASR::StructType_t>(
Expand Down Expand Up @@ -2469,6 +2475,89 @@ namespace LCompilers {
return LLVM::CreateLoad(*builder, value_ptr);
}

void LLVMDict::get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value) {

/**
* C++ equivalent:
*
* idx = 0;
*
* while( capacity > idx ) {
* el = key_or_value_list[idx];
* key_mask_value = key_mask[idx];
*
* is_key_skip = key_mask_value == 3; // tombstone
* is_key_set = key_mask_value != 0;
* add_el = is_key_set && !is_key_skip;
* if( add_el ) {
* elements_list.append(el);
* }
*
* idx++;
* }
*
*/

llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict);
if( !are_iterators_set ) {
idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
}
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 0)), idx_ptr);

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value *cond = builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr));
builder->CreateCondBr(cond, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
llvm_utils->create_ptr_gep(key_mask, idx));
llvm::Value* is_key_skip = builder->CreateICmpEQ(key_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)));
llvm::Value* is_key_set = builder->CreateICmpNE(key_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));

llvm::Value* add_el = builder->CreateAnd(is_key_set,
builder->CreateNot(is_key_skip));
llvm_utils->create_if_else(add_el, [&]() {
llvm::Value* el = llvm_utils->list_api->read_item(el_list, idx,
false, module, LLVM::is_llvm_struct(el_asr_type));
llvm_utils->list_api->append(elements_list, el,
el_asr_type, &module, name2memidx);
}, [=]() {
});

idx = builder->CreateAdd(idx, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, idx, idx_ptr);
}

builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);
}

void LLVMDictSeparateChaining::get_elements_list(llvm::Value* /*dict*/,
llvm::Value* /*elements_list*/, ASR::ttype_t* /*el_asr_type*/, llvm::Module& /*module*/,
std::map<std::string, std::map<std::string, int>>& /*name2memidx*/,
bool /*key_or_value*/) {}

llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
bool enable_bounds_checking,
llvm::Module& module, bool get_pointer) {
Expand Down
16 changes: 16 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,12 @@ namespace LCompilers {
virtual
void set_is_dict_present(bool value);

virtual
void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value) = 0;

virtual ~LLVMDictInterface() = 0;

};
Expand Down Expand Up @@ -555,6 +561,11 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value);

virtual ~LLVMDict();
};

Expand Down Expand Up @@ -702,6 +713,11 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module,
std::map<std::string, std::map<std::string, int>>& name2memidx,
bool key_or_value);

virtual ~LLVMDictSeparateChaining();

};
Expand Down
57 changes: 55 additions & 2 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ enum class IntrinsicFunctions : int64_t {
ListReverse,
ListPop,
DictKeys,
DictValues,
SymbolicSymbol,
SymbolicAdd,
SymbolicSub,
Expand Down Expand Up @@ -1148,8 +1149,8 @@ static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnost
ASRUtils::require_impl(ASR::is_a<ASR::Dict_t>(*ASRUtils::expr_type(x.m_args[0])),
"Argument to dict.keys must be of dict type",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASRUtils::check_equal_type(
ASRUtils::get_contained_type(x.m_type),
ASRUtils::require_impl(ASR::is_a<ASR::List_t>(*x.m_type) &&
ASRUtils::check_equal_type(ASRUtils::get_contained_type(x.m_type),
ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 0)),
"Return type of dict.keys must be of list of dict key element type",
x.base.base.loc, diagnostics);
Expand Down Expand Up @@ -1186,6 +1187,52 @@ static inline ASR::asr_t* create_DictKeys(Allocator& al, const Location& loc,

} // namespace DictKeys

namespace DictValues {

static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 1, "Call to dict.values must have no argument",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::Dict_t>(*ASRUtils::expr_type(x.m_args[0])),
"Argument to dict.values must be of dict type",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::List_t>(*x.m_type) &&
ASRUtils::check_equal_type(ASRUtils::get_contained_type(x.m_type),
ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 1)),
"Return type of dict.values must be of list of dict value element type",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_dict_values(Allocator &/*al*/,
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
// TODO: To be implemented for DictConstant expression
return nullptr;
}

static inline ASR::asr_t* create_DictValues(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 1) {
err("Call to dict.values must have no argument", loc);
}

ASR::expr_t* dict_expr = args[0];
ASR::ttype_t *type = ASRUtils::expr_type(dict_expr);
ASR::ttype_t *dict_values_type = ASR::down_cast<ASR::Dict_t>(type)->m_value_type;

Vec<ASR::expr_t*> arg_values;
arg_values.reserve(al, args.size());
for( size_t i = 0; i < args.size(); i++ ) {
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
}
ASR::expr_t* compile_time_value = eval_dict_values(al, loc, arg_values);
ASR::ttype_t *to_type = List(dict_values_type);
return ASR::make_IntrinsicFunction_t(al, loc,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictValues),
args.p, args.size(), 0, to_type, compile_time_value);
}

} // namespace DictValues

namespace Any {

static inline void verify_array(ASR::expr_t* array, ASR::ttype_t* return_type,
Expand Down Expand Up @@ -2261,6 +2308,8 @@ namespace IntrinsicFunctionRegistry {
{nullptr, &ListReverse::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictKeys),
{nullptr, &DictKeys::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictValues),
{nullptr, &DictValues::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol),
{nullptr, &SymbolicSymbol::verify_args}},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
Expand Down Expand Up @@ -2317,6 +2366,8 @@ namespace IntrinsicFunctionRegistry {
"list.pop"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictKeys),
"dict.keys"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::DictValues),
"dict.values"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol),
"Symbol"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd),
Expand Down Expand Up @@ -2363,6 +2414,7 @@ namespace IntrinsicFunctionRegistry {
{"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
{"dict.keys", {&DictKeys::create_DictKeys, &DictKeys::eval_dict_keys}},
{"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}},
{"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}},
{"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}},
{"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}},
Expand Down Expand Up @@ -2478,6 +2530,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(ListReverse)
INTRINSIC_NAME_CASE(ListPop)
INTRINSIC_NAME_CASE(DictKeys)
INTRINSIC_NAME_CASE(DictValues)
INTRINSIC_NAME_CASE(SymbolicSymbol)
INTRINSIC_NAME_CASE(SymbolicAdd)
INTRINSIC_NAME_CASE(SymbolicSub)
Expand Down
17 changes: 16 additions & 1 deletion src/lpython/semantics/python_attribute_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ struct AttributeHandler {
{"set@remove", &eval_set_remove},
{"dict@get", &eval_dict_get},
{"dict@pop", &eval_dict_pop},
{"dict@keys", &eval_dict_keys}
{"dict@keys", &eval_dict_keys},
{"dict@values", &eval_dict_values}
};

modify_attr_set = {"list@append", "list@remove",
Expand Down Expand Up @@ -403,6 +404,20 @@ struct AttributeHandler {
{ throw SemanticError(msg, loc); });
}

static ASR::asr_t* eval_dict_values(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
Vec<ASR::expr_t*> args_with_dict;
args_with_dict.reserve(al, args.size() + 1);
args_with_dict.push_back(al, s);
for(size_t i = 0; i < args.size(); i++) {
args_with_dict.push_back(al, args[i]);
}
ASRUtils::create_intrinsic_function create_function =
ASRUtils::IntrinsicFunctionRegistry::get_create_function("dict.values");
return create_function(al, loc, args_with_dict, [&](const std::string &msg, const Location &loc)
{ throw SemanticError(msg, loc); });
}

}; // AttributeHandler

} // namespace LCompilers::LPython
Expand Down

0 comments on commit f28eaf0

Please sign in to comment.