Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add list comparison #2025

Merged
merged 6 commits into from
Jul 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions integration_tests/test_list_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ def test_list_compare():
l5: list[list[str]] = [[''], ['']]
l6: list[str] = []
l7: list[str] = []
t1: tuple[i32, i32]
t2: tuple[i32, i32]
i: i32

assert l1 < l2
Expand All @@ -16,9 +18,13 @@ def test_list_compare():
assert l2 < l1
assert not (l1 < l2)

assert l3 < l4
l4[0] = l3[0]
assert l4 < l3
l1 = [3,4,5]
l2 = [1,6,7]
assert l2 < l1

# assert l3 < l4
# l4[0] = l3[0]
# assert l4 < l3

for i in range(0, 10):
if i % 2 == 0:
Expand All @@ -29,5 +35,9 @@ def test_list_compare():
l5[1] = l7
if i % 2 == 0:
assert l5[1 - i % 2] < l5[i % 2]

# t1 = (1, 2)
# t2 = (3, 4)
# assert t1 < t2

test_list_compare()
103 changes: 90 additions & 13 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3189,7 +3189,26 @@ namespace LCompilers {
// TODO:
// - ineq operations other than "<"
// - abstract out this code, possibly switch over operators
// - short-circuit. Without initial allocation of res? Also for equality
// - short-circuit without initial allocation of res? Also for equality

/**
* Equivalent in C++
* For "<"
*
* equality_holds = 1;
* inequality_holds = 0;
* i = 0;
*
* while( i < a_len && i < b_len && equality_holds ) {
* equality_holds &= (a[i] == b[i]);
* inequality_holds |= (a[i] < b[i]);
* }
*
* if( i == a_len && a_len < b_len && equality_holds ) {
* inequality_holds = 1;
* }
*
*/

llvm::AllocaInst *equality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
Expand All @@ -3215,6 +3234,7 @@ namespace LCompilers {
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* cnd = builder->CreateICmpSLT(i, a_len);
cnd = builder->CreateAnd(cnd, builder->CreateICmpSLT(i, b_len));
cnd = builder->CreateAnd(cnd, LLVM::CreateLoad(*builder, equality_holds));
builder->CreateCondBr(cnd, loopbody, loopend);
}

Expand Down Expand Up @@ -3244,8 +3264,9 @@ namespace LCompilers {
// end
llvm_utils->start_new_block(loopend);

// if a_len < b_len && equality_holds, then left < right
llvm::Value* cond = builder->CreateICmpSLT(a_len, b_len);
llvm::Value* cond = builder->CreateICmpEQ(LLVM::CreateLoad(*builder, idx),
a_len);
cond = builder->CreateAnd(cond, builder->CreateICmpSLT(a_len, b_len));
cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, equality_holds));
llvm_utils->create_if_else(cond, [&]() {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
Expand Down Expand Up @@ -3408,17 +3429,73 @@ namespace LCompilers {
llvm::IRBuilder<>* builder,
llvm::Module& module) {
// TODO: operators other than "<"
llvm::Value* inequality_holds = llvm::ConstantInt::get(context, llvm::APInt(1, 0));
for( size_t i = 0; i < tuple_type->n_type; i++ ) {
llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i, LLVM::is_llvm_struct(
tuple_type->m_type[i]));
llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct(
tuple_type->m_type[i]));
llvm::Value* res = llvm_utils->is_less_by_value(t1i, t2i, module,
tuple_type->m_type[i]);
inequality_holds = builder->CreateOr(inequality_holds, res);

/**
* Equivalent in C++
* For "<"
*
* equality_holds = 1;
* inequality_holds = 0;
* i = 0;
*
* while( i < a_len && equality_holds ) {
* equality_holds &= (a[i] == b[i]);
* inequality_holds |= (a[i] < b[i]);
* }
*
*/

llvm::AllocaInst *equality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)),
equality_holds);
llvm::AllocaInst *inequality_holds = builder->CreateAlloca(
llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)),
inequality_holds);

llvm::AllocaInst *idx = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(
context, llvm::APInt(32, 0)), idx);
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
llvm::Value* cnd = builder->CreateICmpSLT(i, llvm::ConstantInt::get(
context, llvm::APInt(32, tuple_type->n_type)));
cnd = builder->CreateAnd(cnd, LLVM::CreateLoad(*builder, equality_holds));
builder->CreateCondBr(cnd, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value* i = LLVM::CreateLoad(*builder, idx);
// llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i, LLVM::is_llvm_struct(
// tuple_type->m_type[i]));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we index tuple_type->m_type with run time value i?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't. :-). Tuples should only be indexed with fixed indices available at compile time so that type of the indexed item can be figured out at compile time itself. Its the pattern as C++ std::tuple.

// llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct(
// tuple_type->m_type[i]));
// llvm::Value* res = llvm_utils->is_less_by_value(t1i, t2i, module,
// tuple_type->m_type[i]);
// res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res);
// LLVM::CreateStore(*builder, res, inequality_holds);
// res = llvm_utils->is_equal_by_value(t1i, t2i, module, tuple_type->m_type[i]);
// res = builder->CreateAnd(LLVM::CreateLoad(*builder, equality_holds), res);
// LLVM::CreateStore(*builder, res, equality_holds);
i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, i, idx);
}
return inequality_holds;

builder->CreateBr(loophead);

// end
llvm_utils->start_new_block(loopend);
return LLVM::CreateLoad(*builder, inequality_holds);
}

void LLVMTuple::concat(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type_1,
Expand Down
9 changes: 6 additions & 3 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6208,15 +6208,18 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {

tmp = ASR::make_StringCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
} else if (ASR::is_a<ASR::Tuple_t>(*dest_type)) {
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq) {
throw SemanticError("Only Equal and Not-equal operators are supported for Tuples",
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq
&& asr_op != ASR::cmpopType::Lt) {
throw SemanticError("Only Equal, Not-equal and Less-than operators "
"are supported for Tuples",
x.base.base.loc);
}
tmp = ASR::make_TupleCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
} else if (ASR::is_a<ASR::List_t>(*dest_type)) {
if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq
&& asr_op != ASR::cmpopType::Lt) {
throw SemanticError("Only Equal, Not-equal and Less-than operators are supported for Lists",
throw SemanticError("Only Equal, Not-equal and Less-than operators "
"are supported for Lists",
x.base.base.loc);
}
tmp = ASR::make_ListCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);
Expand Down
Loading