Skip to content

Commit

Permalink
Support runtime do loop increments in C backend (lcompilers#1423)
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Jan 16, 2023
1 parent d3a947f commit 995d972
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 70 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ RUN(NAME expr_13 LABELS llvm c
RUN(NAME expr_14 LABELS cpython llvm c)
RUN(NAME loop_01 LABELS cpython llvm c)
RUN(NAME loop_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME loop_03 LABELS cpython llvm c wasm)
RUN(NAME if_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME if_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
RUN(NAME print_02 LABELS cpython llvm)
Expand Down
56 changes: 56 additions & 0 deletions integration_tests/loop_03.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from ltypes import i32

def test_loop_01():
i: i32 = 0
j: i32 = 0
rt_inc: i32
rt_inc = 1

while False:
assert False

while i < 0:
assert False

while i < 10:
i += 1
assert i == 10

while i < 20:
while i < 15:
i += 1
i += 1
assert i == 20

for i in range(0, 5, rt_inc):
assert i == j
j += 1

def test_loop_02():
i: i32 = 0
j: i32 = 0
rt_inc_neg_1: i32 = -1
rt_inc_1: i32 = 1

j = 0
for i in range(10, 0, rt_inc_neg_1):
j = j + i
assert j == 55

for i in range(0, 5, rt_inc_1):
if i == 3:
break
assert i == 3

j = 0
for i in range(0, 5, rt_inc_1):
if i == 3:
continue
j += 1
assert j == 4

def verify():
test_loop_01()
test_loop_02()

verify()
72 changes: 48 additions & 24 deletions src/libasr/codegen/asr_to_c_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
std::unique_ptr<CCPPDSUtils> c_ds_api;
std::string const_name;
size_t const_vars_count;
size_t loop_end_count;

SymbolTable* current_scope;
bool is_string_concat_present;
Expand All @@ -95,7 +96,8 @@ class BaseCCPPVisitor : public ASR::BaseVisitor<Struct>
is_c{is_c}, global_scope{nullptr}, lower_bound{default_lower_bound},
template_number{0}, c_ds_api{std::make_unique<CCPPDSUtils>(is_c, platform)},
const_name{"constname"},
const_vars_count{0}, is_string_concat_present{false} {
const_vars_count{0}, loop_end_count{0},
is_string_concat_present{false} {
}

void visit_TranslationUnit(const ASR::TranslationUnit_t &x) {
Expand Down Expand Up @@ -1738,6 +1740,7 @@ R"(#include <stdio.h>
void visit_DoLoop(const ASR::DoLoop_t &x) {
std::string current_body_copy = current_body;
current_body = "";
std::string loop_end_decl = "";
std::string indent(indentation_level*indentation_spaces, ' ');
std::string out = indent + "for (";
ASR::Variable_t *loop_var = ASRUtils::EXPR2VAR(x.m_head.m_v);
Expand All @@ -1748,34 +1751,55 @@ R"(#include <stdio.h>
LCOMPILERS_ASSERT(a);
LCOMPILERS_ASSERT(b);
int increment;
bool is_c_constant = false;
if (!c) {
increment = 1;
is_c_constant = true;
} else {
c = ASRUtils::expr_value(c);
bool is_c_constant = ASRUtils::extract_value(c, increment);
if( !is_c_constant ) {
throw CodeGenError("Do loop increment type not supported");
}
ASR::expr_t* c_value = ASRUtils::expr_value(c);
is_c_constant = ASRUtils::extract_value(c_value, increment);
}
std::string cmp_op;
if (increment > 0) {
cmp_op = "<=";
} else {
cmp_op = ">=";
}

out += lvname + "=";
self().visit_expr(*a);
out += src + "; " + lvname + cmp_op;
self().visit_expr(*b);
out += src + "; " + lvname;
if (increment == 1) {
out += "++";
} else if (increment == -1) {
out += "--";

if( is_c_constant ) {
std::string cmp_op;
if (increment > 0) {
cmp_op = "<=";
} else {
cmp_op = ">=";
}

out += lvname + "=";
self().visit_expr(*a);
out += src + "; " + lvname + cmp_op;
self().visit_expr(*b);
out += src + "; " + lvname;
if (increment == 1) {
out += "++";
} else if (increment == -1) {
out += "--";
} else {
out += "+=" + std::to_string(increment);
}
} else {
out += "+=" + std::to_string(increment);
this->visit_expr(*c);
std::string increment_ = std::move(src);
self().visit_expr(*b);
std::string do_loop_end = std::move(src);
std::string do_loop_end_name = current_scope->get_unique_name(
"loop_end___" + std::to_string(loop_end_count));
loop_end_count += 1;
loop_end_decl = indent + CUtils::get_c_type_from_ttype_t(ASRUtils::expr_type(b), is_c) +
" " + do_loop_end_name + " = " + do_loop_end + ";\n";
out += lvname + " = ";
self().visit_expr(*a);
out += src + "; ";
out += "((" + increment_ + " >= 0) && (" +
lvname + " <= " + do_loop_end_name + ")) || (("
+ increment_ + " < 0) && (" + lvname + " >= "
+ do_loop_end_name + ")); " + lvname;
out += " += " + increment_;
}

out += ") {\n";
indentation_level += 1;
for (size_t i=0; i<x.n_body; i++) {
Expand All @@ -1785,7 +1809,7 @@ R"(#include <stdio.h>
out += current_body;
out += indent + "}\n";
indentation_level -= 1;
src = out;
src = loop_end_decl + out;
current_body = current_body_copy;
}

Expand Down
1 change: 1 addition & 0 deletions src/libasr/codegen/asr_to_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,7 @@ class ASRToWASMVisitor : public ASR::BaseVisitor<ASRToWASMVisitor> {
if (arg_kind > 0 && dest_kind > 0) {
if (arg_kind == 4 && dest_kind == 8) {
wasm::emit_i64_extend_i32_s(m_code_section, m_al);
} else if (arg_kind == 4 && dest_kind == 4) {
} else {
std::string msg = "Conversion from kinds " +
std::to_string(arg_kind) + " to " +
Expand Down
42 changes: 30 additions & 12 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4261,6 +4261,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
a_kind, nullptr, 0));
ASR::expr_t *constant_one = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(
al, loc, 1, a_type));
ASR::expr_t *constant_neg_one = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(
al, loc, -1, a_type));
ASR::expr_t *constant_zero = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(
al, loc, 0, a_type));
if (!loop_start) {
loop_start = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(al, loc, 0, a_type));
}
Expand Down Expand Up @@ -4305,20 +4309,34 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
ASR::expr_t *inc_value = ASRUtils::expr_value(inc);
int64_t inc_int = 1;
bool is_value_present = ASRUtils::extract_value(inc_value, inc_int);
if (!is_value_present) {
throw SemanticError("For loop increment should Compile time constant.", loc);
}

// Loop end depends upon the sign of m_increment.
// if inc > 0 then: loop_end -=1 else loop_end += 1
ASR::binopType offset_op;
if (inc_int < 0 ) {
offset_op = ASR::binopType::Add;
if (is_value_present) {
// Loop end depends upon the sign of m_increment.
// if inc > 0 then: loop_end -=1 else loop_end += 1
ASR::binopType offset_op;
if (inc_int < 0 ) {
offset_op = ASR::binopType::Add;
} else {
offset_op = ASR::binopType::Sub;
}
make_BinOp_helper(loop_end, constant_one,
offset_op, loc, false);
} else {
offset_op = ASR::binopType::Sub;
ASR::ttype_t* logical_type = ASRUtils::TYPE(ASR::make_Logical_t(al, inc->base.loc, 4, nullptr, 0));
ASR::expr_t* inc_pos = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, inc->base.loc, inc,
ASR::cmpopType::GtE, constant_zero, logical_type, nullptr));
ASR::expr_t* inc_neg = ASRUtils::EXPR(ASR::make_IntegerCompare_t(al, inc->base.loc, inc,
ASR::cmpopType::Lt, constant_zero, logical_type, nullptr));
cast_helper(a_type, inc_pos, inc->base.loc, true);
cast_helper(a_type, inc_neg, inc->base.loc, true);
make_BinOp_helper(inc_pos, constant_neg_one, ASR::binopType::Mul, inc->base.loc, false);
ASR::expr_t* case_1 = ASRUtils::EXPR(tmp);
make_BinOp_helper(inc_neg, constant_one, ASR::binopType::Mul, inc->base.loc, false);
ASR::expr_t* case_2 = ASRUtils::EXPR(tmp);
make_BinOp_helper(case_1, case_2, ASR::binopType::Add, inc->base.loc, false);
ASR::expr_t* cases_combined = ASRUtils::EXPR(tmp);
make_BinOp_helper(loop_end, cases_combined, ASR::binopType::Add, loop_end->base.loc, false);
}
make_BinOp_helper(loop_end, constant_one,
offset_op, loc, false);

head.m_end = ASRUtils::EXPR(tmp);


Expand Down
8 changes: 0 additions & 8 deletions tests/errors/test_for1.py

This file was deleted.

13 changes: 0 additions & 13 deletions tests/reference/asr-test_for1-260404e.json

This file was deleted.

9 changes: 0 additions & 9 deletions tests/reference/asr-test_for1-260404e.stderr

This file was deleted.

4 changes: 0 additions & 4 deletions tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -885,10 +885,6 @@ asr = true
filename = "errors/test_tuple1.py"
asr = true

[[test]]
filename = "errors/test_for1.py"
asr = true

[[test]]
filename = "errors/test_for2.py"
asr = true
Expand Down

0 comments on commit 995d972

Please sign in to comment.