Skip to content

Commit

Permalink
Merge pull request lcompilers#2213 from Smit-create/i-2131
Browse files Browse the repository at this point in the history
ASR: Support chained CompareOp
  • Loading branch information
certik committed Jul 29, 2023
2 parents 94fb130 + fc92d7f commit 1b5bafd
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 72 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ RUN(NAME expr_18 FAIL LABELS cpython llvm c)
RUN(NAME expr_19 LABELS cpython llvm c)
RUN(NAME expr_20 LABELS cpython llvm c)
RUN(NAME expr_21 LABELS cpython llvm c)
RUN(NAME expr_22 LABELS cpython llvm c)

RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
RUN(NAME expr_02u LABELS cpython llvm c NOFAST)
Expand Down Expand Up @@ -661,6 +662,7 @@ RUN(NAME structs_31 LABELS cpython llvm c)
RUN(NAME structs_32 LABELS cpython llvm c)
RUN(NAME structs_33 LABELS cpython llvm c)
RUN(NAME structs_34 LABELS cpython llvm c)
RUN(NAME structs_35 LABELS cpython llvm c)

RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
Expand Down
22 changes: 22 additions & 0 deletions integration_tests/expr_22.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from lpython import i32

def f():
x: i32 = 2
y: i32 = 1
z: i32 = 1
t: i32 = 1
assert x > y == z
assert not (x == y == z)
assert y == z == t != x
assert x > y == z >= t
t = 0
assert x > y == z >= t
t = 4
assert not (x > y == z >= t)
assert t > x > y == z
assert 3 > 2 >= 0 <= 6
assert t > y < x
assert not (2 == 3 > 4)


f()
32 changes: 32 additions & 0 deletions integration_tests/structs_35.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from lpython import (i8, i32, i64, f32, f64,
dataclass
)
from numpy import (empty,
int8,
)

# test issue 2131

@dataclass
class Foo:
a : i8[4] = empty(4, dtype=int8)
dim : i32 = 4

def trinary_majority(x : Foo, y : Foo, z : Foo) -> Foo:
foo : Foo = Foo()

assert foo.dim == x.dim == y.dim == z.dim

return foo


t1 : Foo = Foo()
t1.a = empty(4, dtype=int8)

t2 : Foo = Foo()
t2.a = empty(4, dtype=int8)

t3 : Foo = Foo()
t3.a = empty(4, dtype=int8)

r1 : Foo = trinary_majority(t1, t2, t3)
11 changes: 9 additions & 2 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
std::string unit_src = "";
indentation_level = 0;
indentation_spaces = 4;
SymbolTable* current_scope_copy = current_scope;
current_scope = global_scope;
c_ds_api->set_indentation(indentation_level, indentation_spaces);
c_ds_api->set_global_scope(global_scope);
c_utils_functions->set_indentation(indentation_level, indentation_spaces);
Expand Down Expand Up @@ -760,6 +762,7 @@ R"(
out_file.close();
}
}
current_scope = current_scope_copy;
}

void visit_Module(const ASR::Module_t &x) {
Expand All @@ -768,7 +771,8 @@ R"(
} else {
intrinsic_module = false;
}

SymbolTable *current_scope_copy = current_scope;
current_scope = x.m_symtab;
std::string unit_src = "";
for (auto &item : x.m_symtab->get_scope()) {
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
Expand Down Expand Up @@ -813,13 +817,15 @@ R"(
}
src = unit_src;
intrinsic_module = false;
current_scope = current_scope_copy;
}

void visit_Program(const ASR::Program_t &x) {
// Topologically sort all program functions
// and then define them in the right order
std::vector<std::string> func_order = ASRUtils::determine_function_definition_order(x.m_symtab);

SymbolTable *current_scope_copy = current_scope;
current_scope = x.m_symtab;
// Generate code for nested subroutines and functions first:
std::string contains;
for (auto &item : func_order) {
Expand Down Expand Up @@ -898,6 +904,7 @@ R"( // Initialise Numpy
+ decl + body
+ indent1 + "return 0;\n}\n";
indentation_level -= 2;
current_scope = current_scope_copy;
}

template <typename T>
Expand Down
3 changes: 3 additions & 0 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>

void visit_TranslationUnit(const ASR::TranslationUnit_t &x) {
global_scope = x.m_global_scope;
SymbolTable* current_scope_copy = current_scope;
current_scope = global_scope;
// All loose statements must be converted to a function, so the items
// must be empty:
LCOMPILERS_ASSERT(x.n_items == 0);
Expand Down Expand Up @@ -255,6 +257,7 @@ R"(#include <stdio.h>
}

src = unit_src;
current_scope = current_scope_copy;
}

std::string check_tmp_buffer() {
Expand Down
Loading

0 comments on commit 1b5bafd

Please sign in to comment.