diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 2a08363d08..fd6cbc095d 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -457,6 +457,7 @@ RUN(NAME test_list_repeat LABELS cpython llvm NOFAST) RUN(NAME test_list_reverse LABELS cpython llvm) RUN(NAME test_list_pop LABELS cpython llvm NOFAST) # TODO: Remove NOFAST from here. RUN(NAME test_list_pop2 LABELS cpython llvm NOFAST) # TODO: Remove NOFAST from here. +RUN(NAME test_list_compare LABELS cpython llvm) RUN(NAME test_tuple_01 LABELS cpython llvm c) RUN(NAME test_tuple_02 LABELS cpython llvm c NOFAST) RUN(NAME test_tuple_03 LABELS cpython llvm c) diff --git a/integration_tests/test_list_compare.py b/integration_tests/test_list_compare.py new file mode 100644 index 0000000000..24fcc485d5 --- /dev/null +++ b/integration_tests/test_list_compare.py @@ -0,0 +1,53 @@ +from lpython import i32, f64 + +def test_list_compare(): + l1: list[i32] = [1, 2, 3] + l2: list[i32] = [1, 2, 3, 4] + l3: list[tuple[i32, f64, str]] = [(1, 2.0, 'a'), (3, 4.0, 'b')] + l4: list[tuple[i32, f64, str]] = [(1, 3.0, 'a')] + l5: list[list[str]] = [[''], ['']] + l6: list[str] = [] + l7: list[str] = [] + t1: tuple[i32, i32] + t2: tuple[i32, i32] + i: i32 + + assert l1 < l2 and l1 <= l2 + assert not l1 > l2 and not l1 >= l2 + i = l2.pop() + i = l2.pop() + assert l2 < l1 and l1 > l2 and l1 >= l2 + assert not (l1 < l2) + + l1 = [3, 4, 5] + l2 = [1, 6, 7] + assert l1 > l2 and l1 >= l2 + assert not l1 < l2 and not l1 <= l2 + + l1 = l2 + assert l1 == l2 and l1 <= l2 and l1 >= l2 + assert not l1 < l2 and not l1 > l2 + + assert l4 > l3 and l4 >= l3 + l4[0] = l3[0] + assert l4 < l3 + + for i in range(0, 10): + if i % 2 == 0: + l6.append('a') + else: + l7.append('a') + l5[0] = l6 + l5[1] = l7 + if i % 2 == 0: + assert l5[1 - i % 2] < l5[i % 2] + assert l5[1 - i % 2] <= l5[i % 2] + assert not l5[1 - i % 2] > l5[i % 2] + assert not l5[1 - i % 2] >= l5[i % 2] + + t1 = (1, 2) + t2 = (2, 3) + assert t1 < t2 and t1 <= t2 + assert not t1 > t2 and not t1 >= t2 + +test_list_compare() \ No newline at end of file diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 1170bb491f..3cd6826cad 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1427,10 +1427,31 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_right); llvm::Value* right = tmp; ptr_loads = ptr_loads_copy; - tmp = llvm_utils->is_equal_by_value(left, right, *module, - ASRUtils::expr_type(x.m_left)); - if (x.m_op == ASR::cmpopType::NotEq) { - tmp = builder->CreateNot(tmp); + + ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4)); + + if(x.m_op == ASR::cmpopType::Eq || x.m_op == ASR::cmpopType::NotEq) { + tmp = llvm_utils->is_equal_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left)); + if (x.m_op == ASR::cmpopType::NotEq) { + tmp = builder->CreateNot(tmp); + } + } + else if(x.m_op == ASR::cmpopType::Lt) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 0, int32_type); + } + else if(x.m_op == ASR::cmpopType::LtE) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 1, int32_type); + } + else if(x.m_op == ASR::cmpopType::Gt) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 2, int32_type); + } + else if(x.m_op == ASR::cmpopType::GtE) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 3, int32_type); } } @@ -1761,10 +1782,28 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_right); llvm::Value* right = tmp; ptr_loads = ptr_loads_copy; - tmp = llvm_utils->is_equal_by_value(left, right, *module, - ASRUtils::expr_type(x.m_left)); - if (x.m_op == ASR::cmpopType::NotEq) { - tmp = builder->CreateNot(tmp); + if(x.m_op == ASR::cmpopType::Eq || x.m_op == ASR::cmpopType::NotEq) { + tmp = llvm_utils->is_equal_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left)); + if (x.m_op == ASR::cmpopType::NotEq) { + tmp = builder->CreateNot(tmp); + } + } + else if(x.m_op == ASR::cmpopType::Lt) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 0); + } + else if(x.m_op == ASR::cmpopType::LtE) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 1); + } + else if(x.m_op == ASR::cmpopType::Gt) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 2); + } + else if(x.m_op == ASR::cmpopType::GtE) { + tmp = llvm_utils->is_ineq_by_value(left, right, *module, + ASRUtils::expr_type(x.m_left), 3); } } diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 0599e97130..4c2a241a84 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1452,6 +1452,153 @@ namespace LCompilers { } } + llvm::Value* LLVMUtils::is_ineq_by_value(llvm::Value* left, llvm::Value* right, + llvm::Module& module, ASR::ttype_t* asr_type, + int8_t overload_id, ASR::ttype_t* int32_type) { + /** + * overloads: + * 0 < + * 1 <= + * 2 > + * 3 >= + */ + llvm::CmpInst::Predicate pred; + + switch( asr_type->type ) { + case ASR::ttypeType::Integer: + case ASR::ttypeType::Logical: { + switch( overload_id ) { + case 0: { + pred = llvm::CmpInst::Predicate::ICMP_SLT; + break; + } + case 1: { + pred = llvm::CmpInst::Predicate::ICMP_SLE; + break; + } + case 2: { + pred = llvm::CmpInst::Predicate::ICMP_SGT; + break; + } + case 3: { + pred = llvm::CmpInst::Predicate::ICMP_SGE; + break; + } + default: { + // can exit with error + } + } + return builder->CreateCmp(pred, left, right); + } + case ASR::ttypeType::Real: { + switch( overload_id ) { + case 0: { + pred = llvm::CmpInst::Predicate::FCMP_OLT; + break; + } + case 1: { + pred = llvm::CmpInst::Predicate::FCMP_OLE; + break; + } + case 2: { + pred = llvm::CmpInst::Predicate::FCMP_OGT; + break; + } + case 3: { + pred = llvm::CmpInst::Predicate::FCMP_OGE; + break; + } + default: { + // can exit with error + } + } + return builder->CreateCmp(pred, left, right); + } + case ASR::ttypeType::Character: { + if( !are_iterators_set ) { + str_cmp_itr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* idx = str_cmp_itr; + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), + idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + llvm::Value *cond = builder->CreateAnd( + builder->CreateICmpNE(l, null_char), + builder->CreateICmpNE(r, null_char) + ); + switch( overload_id ) { + case 0: { + pred = llvm::CmpInst::Predicate::ICMP_ULT; + break; + } + case 1: { + pred = llvm::CmpInst::Predicate::ICMP_ULE; + break; + } + case 2: { + pred = llvm::CmpInst::Predicate::ICMP_UGT; + break; + } + case 3: { + pred = llvm::CmpInst::Predicate::ICMP_UGE; + break; + } + default: { + // can exit with error + } + } + cond = builder->CreateAnd(cond, builder->CreateCmp(pred, l, r)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + start_new_block(loopend); + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* l = LLVM::CreateLoad(*builder, create_ptr_gep(left, i)); + llvm::Value* r = LLVM::CreateLoad(*builder, create_ptr_gep(right, i)); + return builder->CreateICmpULT(l, r); + } + case ASR::ttypeType::Tuple: { + ASR::Tuple_t* tuple_type = ASR::down_cast(asr_type); + return tuple_api->check_tuple_inequality(left, right, tuple_type, context, + builder, module, overload_id); + } + case ASR::ttypeType::List: { + ASR::List_t* list_type = ASR::down_cast(asr_type); + return list_api->check_list_inequality(left, right, list_type->m_type, + context, builder, module, + overload_id, int32_type); + } + default: { + throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " + + ASRUtils::type_to_str_python(asr_type)); + } + } + } + void LLVMUtils::deepcopy(llvm::Value* src, llvm::Value* dest, ASR::ttype_t* asr_type, llvm::Module* module, std::map>& name2memidx) { @@ -4276,6 +4423,101 @@ namespace LCompilers { return LLVM::CreateLoad(*builder, is_equal); } + llvm::Value* LLVMList::check_list_inequality(llvm::Value* l1, llvm::Value* l2, + ASR::ttype_t* item_type, + llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, + llvm::Module& module, int8_t overload_id, + ASR::ttype_t* int32_type) { + /** + * Equivalent in C++ + * + * equality_holds = 1; + * inequality_holds = 0; + * i = 0; + * + * while( i < a_len && i < b_len && equality_holds ) { + * equality_holds &= (a[i] == b[i]); + * inequality_holds |= (a[i] op b[i]); + * i++; + * } + * + * if( (i == a_len || i == b_len) && equality_holds ) { + * inequality_holds = a_len op b_len; + * } + * + */ + + llvm::AllocaInst *equality_holds = builder->CreateAlloca( + llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)), + equality_holds); + llvm::AllocaInst *inequality_holds = builder->CreateAlloca( + llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)), + inequality_holds); + + llvm::Value *a_len = llvm_utils->list_api->len(l1); + llvm::Value *b_len = llvm_utils->list_api->len(l2); + llvm::AllocaInst *idx = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get( + context, llvm::APInt(32, 0)), idx); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* cnd = builder->CreateICmpSLT(i, a_len); + cnd = builder->CreateAnd(cnd, builder->CreateICmpSLT(i, b_len)); + cnd = builder->CreateAnd(cnd, LLVM::CreateLoad(*builder, equality_holds)); + builder->CreateCondBr(cnd, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* i = LLVM::CreateLoad(*builder, idx); + llvm::Value* left_arg = llvm_utils->list_api->read_item(l1, i, + false, module, LLVM::is_llvm_struct(item_type)); + llvm::Value* right_arg = llvm_utils->list_api->read_item(l2, i, + false, module, LLVM::is_llvm_struct(item_type)); + llvm::Value* res = llvm_utils->is_ineq_by_value(left_arg, right_arg, module, + item_type, overload_id); + res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res); + LLVM::CreateStore(*builder, res, inequality_holds); + res = llvm_utils->is_equal_by_value(left_arg, right_arg, module, + item_type); + res = builder->CreateAnd(LLVM::CreateLoad(*builder, equality_holds), res); + LLVM::CreateStore(*builder, res, equality_holds); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, i, idx); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + + llvm::Value* cond = builder->CreateICmpEQ(LLVM::CreateLoad(*builder, idx), + a_len); + cond = builder->CreateOr(cond, builder->CreateICmpEQ( + LLVM::CreateLoad(*builder, idx), b_len)); + cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, equality_holds)); + llvm_utils->create_if_else(cond, [&]() { + LLVM::CreateStore(*builder, llvm_utils->is_ineq_by_value(a_len, b_len, + module, int32_type, overload_id), inequality_holds); + }, []() { + // LLVM::CreateStore(*builder, llvm::ConstantInt::get( + // context, llvm::APInt(1, 0)), inequality_holds); + }); + + return LLVM::CreateLoad(*builder, inequality_holds); + } + void LLVMList::list_repeat_copy(llvm::Value* repeat_list, llvm::Value* init_list, llvm::Value* num_times, llvm::Value* init_list_len, llvm::Module* module) { @@ -4421,6 +4663,58 @@ namespace LCompilers { return is_equal; } + llvm::Value* LLVMTuple::check_tuple_inequality(llvm::Value* t1, llvm::Value* t2, + ASR::Tuple_t* tuple_type, + llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, + llvm::Module& module, int8_t overload_id) { + /** + * Equivalent in C++ + * + * equality_holds = 1; + * inequality_holds = 0; + * i = 0; + * + * // owing to compile-time access of indices, + * // loop is unrolled into multiple if statements + * while( i < a_len && equality_holds ) { + * inequality_holds |= (a[i] op b[i]); + * equality_holds &= (a[i] == b[i]); + * i++; + * } + * + * return inequality_holds; + * + */ + + llvm::AllocaInst *equality_holds = builder->CreateAlloca( + llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 1)), + equality_holds); + llvm::AllocaInst *inequality_holds = builder->CreateAlloca( + llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(context, llvm::APInt(1, 0)), + inequality_holds); + + for( size_t i = 0; i < tuple_type->n_type; i++ ) { + llvm_utils->create_if_else(LLVM::CreateLoad(*builder, equality_holds), [&]() { + llvm::Value* t1i = llvm_utils->tuple_api->read_item(t1, i, LLVM::is_llvm_struct( + tuple_type->m_type[i])); + llvm::Value* t2i = llvm_utils->tuple_api->read_item(t2, i, LLVM::is_llvm_struct( + tuple_type->m_type[i])); + llvm::Value* res = llvm_utils->is_ineq_by_value(t1i, t2i, module, + tuple_type->m_type[i], overload_id); + res = builder->CreateOr(LLVM::CreateLoad(*builder, inequality_holds), res); + LLVM::CreateStore(*builder, res, inequality_holds); + res = llvm_utils->is_equal_by_value(t1i, t2i, module, tuple_type->m_type[i]); + res = builder->CreateAnd(LLVM::CreateLoad(*builder, equality_holds), res); + LLVM::CreateStore(*builder, res, equality_holds); + }, [](){}); + } + + return LLVM::CreateLoad(*builder, inequality_holds); + } + void LLVMTuple::concat(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type_1, ASR::Tuple_t* tuple_type_2, llvm::Value* concat_tuple, ASR::Tuple_t* concat_tuple_type, llvm::Module& module, diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index fcdfaddeb1..54bbaef366 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -220,6 +220,10 @@ namespace LCompilers { llvm::Value* is_equal_by_value(llvm::Value* left, llvm::Value* right, llvm::Module& module, ASR::ttype_t* asr_type); + llvm::Value* is_ineq_by_value(llvm::Value* left, llvm::Value* right, + llvm::Module& module, ASR::ttype_t* asr_type, + int8_t overload_id, ASR::ttype_t* int32_type=nullptr); + void set_iterators(); void reset_iterators(); @@ -413,6 +417,11 @@ namespace LCompilers { llvm::Value* check_list_equality(llvm::Value* l1, llvm::Value* l2, ASR::ttype_t *item_type, llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module); + llvm::Value* check_list_inequality(llvm::Value* l1, llvm::Value* l2, + ASR::ttype_t *item_type, llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, llvm::Module& module, + int8_t overload_id, ASR::ttype_t* int32_type=nullptr); + void list_repeat_copy(llvm::Value* repeat_list, llvm::Value* init_list, llvm::Value* num_times, llvm::Value* init_list_len, llvm::Module* module); @@ -454,6 +463,10 @@ namespace LCompilers { ASR::Tuple_t* tuple_type, llvm::LLVMContext& context, llvm::IRBuilder<>* builder, llvm::Module& module); + llvm::Value* check_tuple_inequality(llvm::Value* t1, llvm::Value* t2, + ASR::Tuple_t* tuple_type, llvm::LLVMContext& context, + llvm::IRBuilder<>* builder, llvm::Module& module, int8_t overload_id); + void concat(llvm::Value* t1, llvm::Value* t2, ASR::Tuple_t* tuple_type_1, ASR::Tuple_t* tuple_type_2, llvm::Value* concat_tuple, ASR::Tuple_t* concat_tuple_type, llvm::Module& module, diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 2f876f394f..1611d23d04 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6303,14 +6303,20 @@ class BodyVisitor : public CommonVisitor { tmp = ASR::make_StringCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); } else if (ASR::is_a(*dest_type)) { - if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq) { - throw SemanticError("Only Equal and Not-equal operators are supported for Tuples", + if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq + && asr_op != ASR::cmpopType::Lt && asr_op != ASR::cmpopType::LtE + && asr_op != ASR::cmpopType::Gt && asr_op != ASR::cmpopType::GtE) { + throw SemanticError("Only ==, !=, <, <=, >, >= operators " + "are supported for Tuples", x.base.base.loc); } tmp = ASR::make_TupleCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); } else if (ASR::is_a(*dest_type)) { - if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq) { - throw SemanticError("Only Equal and Not-equal operators are supported for Tuples", + if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq + && asr_op != ASR::cmpopType::Lt && asr_op != ASR::cmpopType::LtE + && asr_op != ASR::cmpopType::Gt && asr_op != ASR::cmpopType::GtE) { + throw SemanticError("Only ==, !=, <, <=, >, >= operators " + "are supported for Lists", x.base.base.loc); } tmp = ASR::make_ListCompare_t(al, x.base.base.loc, left, asr_op, right, type, value);