Skip to content

Commit

Permalink
Add list.count and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
virendrakabra14 committed Apr 8, 2023
1 parent b0027d5 commit ffde7ed
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 0 deletions.
33 changes: 33 additions & 0 deletions integration_tests/test_list_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from lpython import i32

def test_list_count():
i: i32
x: list[i32] = []

for i in range(-5, 0):
assert x.count(i) == 0
x.append(i)
assert x.count(i) == 1
x.append(i)
assert x.count(i) == 2
x.remove(i)
assert x.count(i) == 1

assert x == [-5, -4, -3, -2, -1]

for i in range(0, 5):
assert x.count(i) == 0
x.append(i)
assert x.count(i) == 1

assert x == [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]

while len(x)>0:
i = x[-1]
x.remove(i)
assert x.count(i) == 0

assert len(x) == 0
assert x.count(0) == 0

test_list_count()
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ expr
| ListLen(expr arg, ttype type, expr? value)
| ListConcat(expr left, expr right, ttype type, expr? value)
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| ListCount(expr arg, expr ele, ttype type, expr? value)

| SetConstant(expr* elements, ttype type)
| SetLen(expr arg, ttype type, expr? value)
Expand Down
14 changes: 14 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
list_api->remove(plist, item, asr_el_type, *module);
}

void visit_ListCount(const ASR::ListCount_t& x) {
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_arg));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_arg);
llvm::Value* plist = tmp;

ptr_loads = !LLVM::is_llvm_struct(asr_el_type);
this->visit_expr_wrapper(x.m_ele, true);
ptr_loads = ptr_loads_copy;
llvm::Value *item = tmp;
tmp = list_api->count(plist, item, asr_el_type, *module);
}

void visit_ListClear(const ASR::ListClear_t& x) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
Expand Down
77 changes: 77 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2498,6 +2498,83 @@ namespace LCompilers {
return LLVM::CreateLoad(*builder, i);
}

llvm::Value* LLVMList::count(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module) {
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);
llvm::Value* current_end_point = LLVM::CreateLoad(*builder,
get_pointer_to_current_end_point(list));
llvm::AllocaInst *i = builder->CreateAlloca(pos_type, nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(32, 0)), i);
llvm::AllocaInst *cnt = builder->CreateAlloca(pos_type, nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(32, 0)), cnt);
llvm::Value* tmp = nullptr;

/* Equivalent in C++:
* int i = 0;
* int cnt = 0;
* while(end_point > i) {
* if(list[i] == item) {
* tmp = cnt+1;
* cnt = tmp;
* }
* tmp = i+1;
* i = tmp;
* }
*/

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value *cond = builder->CreateICmpSGT(current_end_point,
LLVM::CreateLoad(*builder, i));
builder->CreateCondBr(cond, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
// if occurrence found, increment cnt
llvm::Function *fn = builder->GetInsertBlock()->getParent();
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");

llvm::Value* left_arg = read_item(list, LLVM::CreateLoad(*builder, i),
false, module, LLVM::is_llvm_struct(item_type));
llvm::Value* cond = llvm_utils->is_equal_by_value(left_arg, item, module, item_type);
builder->CreateCondBr(cond, thenBB, elseBB);
builder->SetInsertPoint(thenBB);
{
tmp = builder->CreateAdd(
LLVM::CreateLoad(*builder, cnt),
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, tmp, cnt);
}
builder->CreateBr(mergeBB);

llvm_utils->start_new_block(elseBB);
llvm_utils->start_new_block(mergeBB);

// increment i
tmp = builder->CreateAdd(
LLVM::CreateLoad(*builder, i),
llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, tmp, i);
}
builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);

return LLVM::CreateLoad(*builder, cnt);
}

void LLVMList::remove(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module) {
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);
Expand Down
3 changes: 3 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ namespace LCompilers {
llvm::Value* item, ASR::ttype_t* item_type,
llvm::Module& module);

llvm::Value* count(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module);

void free_data(llvm::Value* list, llvm::Module& module);

llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type,
Expand Down
27 changes: 27 additions & 0 deletions src/lpython/semantics/python_attribute_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct AttributeHandler {
{"int@bit_length", &eval_int_bit_length},
{"list@append", &eval_list_append},
{"list@remove", &eval_list_remove},
{"list@count", &eval_list_count},
{"list@clear", &eval_list_clear},
{"list@insert", &eval_list_insert},
{"list@pop", &eval_list_pop},
Expand Down Expand Up @@ -122,6 +123,32 @@ struct AttributeHandler {
return make_ListRemove_t(al, loc, s, args[0]);
}

static ASR::asr_t* eval_list_count(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
if (args.size() != 1) {
throw SemanticError("count() takes exactly one argument",
loc);
}
ASR::ttype_t *type = ASRUtils::expr_type(s);
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]);
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
std::string fnd = ASRUtils::type_to_str_python(ele_type);
std::string org = ASRUtils::type_to_str_python(list_type);
diag.add(diag::Diagnostic(
"Type mismatch in 'count', the types must be compatible",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')",
{args[0]->base.loc})
})
);
throw SemanticAbort();
}
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
4, nullptr, 0));
return make_ListCount_t(al, loc, s, args[0], to_type, nullptr);
}

static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
if (args.size() != 2) {
Expand Down
6 changes: 6 additions & 0 deletions tests/errors/test_list_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from lpython import i32

def test_list_count_error():
a: list[i32]
a = [1, 2, 3]
a.count(1.0)
13 changes: 13 additions & 0 deletions tests/reference/asr-test_list_count-4b42498.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-test_list_count-4b42498",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/test_list_count.py",
"infile_hash": "01975bd7c4bba02fd811de536b218167da99b532fa955b7bf8339779",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-test_list_count-4b42498.stderr",
"stderr_hash": "f26efcc623b68ca43ef871eb01c8e3cbd1ae464baaa491c6e4969696",
"returncode": 2
}
5 changes: 5 additions & 0 deletions tests/reference/asr-test_list_count-4b42498.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
semantic error: Type mismatch in 'count', the types must be compatible
--> tests/errors/test_list_count.py:6:13
|
6 | a.count(1.0)
| ^^^ type mismatch (found: 'f64', expected: 'i32')
8 changes: 8 additions & 0 deletions tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,14 @@ asr = true
filename = "errors/test_list_concat.py"
asr = true

[[test]]
filename = "errors/test_list_count.py"
asr = true

[[test]]
filename = "../integration_tests/test_list_count.py"
asr_json = true

[[test]]
filename = "errors/test_list1.py"
asr = true
Expand Down

0 comments on commit ffde7ed

Please sign in to comment.