Skip to content

Commit

Permalink
Update ASR from LFortran
Browse files Browse the repository at this point in the history
  • Loading branch information
certik committed Apr 14, 2022
1 parent 1677c91 commit daa19a4
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 5 deletions.
File renamed without changes.
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set(SRC
pass/sign_from_value.cpp
pass/inline_function_calls.cpp
pass/loop_unroll.cpp
pass/dead_code_removal.cpp

asr_verify.cpp
asr_utils.cpp
Expand Down
94 changes: 89 additions & 5 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,33 @@ static inline bool is_value_constant(ASR::expr_t *a_value) {
return true;
}

template <typename T>
static inline bool is_value_constant(ASR::expr_t *a_value, T& const_value) {
static inline bool is_value_constant(ASR::expr_t *a_value, int64_t& const_value) {
if( a_value == nullptr ) {
return false;
}
if (ASR::is_a<ASR::ConstantInteger_t>(*a_value)) {
ASR::ConstantInteger_t* const_int = ASR::down_cast<ASR::ConstantInteger_t>(a_value);
const_value = const_int->m_n;
} else {
return false;
}
return true;
}

static inline bool is_value_constant(ASR::expr_t *a_value, bool& const_value) {
if( a_value == nullptr ) {
return false;
}
if (ASR::is_a<ASR::ConstantLogical_t>(*a_value)) {
ASR::ConstantLogical_t* const_logical = ASR::down_cast<ASR::ConstantLogical_t>(a_value);
const_value = const_logical->m_value;
} else {
return false;
}
return true;
}

static inline bool is_value_constant(ASR::expr_t *a_value, double& const_value) {
if( a_value == nullptr ) {
return false;
}
Expand All @@ -450,15 +475,74 @@ static inline bool is_value_constant(ASR::expr_t *a_value, T& const_value) {
} else if (ASR::is_a<ASR::ConstantReal_t>(*a_value)) {
ASR::ConstantReal_t* const_real = ASR::down_cast<ASR::ConstantReal_t>(a_value);
const_value = const_real->m_r;
} else if (ASR::is_a<ASR::ConstantLogical_t>(*a_value)) {
ASR::ConstantLogical_t* const_logical = ASR::down_cast<ASR::ConstantLogical_t>(a_value);
const_value = const_logical->m_value;
} else {
return false;
}
return true;
}

static inline bool is_value_constant(ASR::expr_t *a_value, std::string& const_value) {
if( a_value == nullptr ) {
return false;
}
if (ASR::is_a<ASR::ConstantString_t>(*a_value)) {
ASR::ConstantString_t* const_string = ASR::down_cast<ASR::ConstantString_t>(a_value);
const_value = std::string(const_string->m_s);
} else {
return false;
}
return true;
}

static inline bool is_value_equal(ASR::expr_t* test_expr, ASR::expr_t* desired_expr) {
ASR::expr_t* test_value = expr_value(test_expr);
ASR::expr_t* desired_value = expr_value(desired_expr);
if( !is_value_constant(test_value) ||
!is_value_constant(desired_value) ||
test_value->type != desired_value->type ) {
return false;
}

switch( desired_value->type ) {
case ASR::exprType::ConstantInteger: {
ASR::ConstantInteger_t* test_int = ASR::down_cast<ASR::ConstantInteger_t>(test_value);
ASR::ConstantInteger_t* desired_int = ASR::down_cast<ASR::ConstantInteger_t>(desired_value);
return test_int->m_n == desired_int->m_n;
}
case ASR::exprType::ConstantString: {
ASR::ConstantString_t* test_str = ASR::down_cast<ASR::ConstantString_t>(test_value);
ASR::ConstantString_t* desired_str = ASR::down_cast<ASR::ConstantString_t>(desired_value);
return std::string(test_str->m_s) == std::string(desired_str->m_s);
}
default: {
return false;
}
}
}

static inline bool is_value_in_range(ASR::expr_t* start, ASR::expr_t* end, ASR::expr_t* value) {
ASR::expr_t *start_value = nullptr, *end_value = nullptr;
if( start ) {
start_value = expr_value(start);
}
if( end ) {
end_value = expr_value(end);
}
ASR::expr_t* test_value = expr_value(value);


double start_double = std::numeric_limits<double>::min();
double end_double = std::numeric_limits<double>::max();
double value_double;
bool start_const = is_value_constant(start_value, start_double);
bool end_const = is_value_constant(end_value, end_double);
bool value_const = is_value_constant(test_value, value_double);
if( !value_const || (!start_const && !end_const) ) {
return false;
}
return value_double >= start_double && value_double <= end_double;
}

// Returns true if all arguments are evaluated
static inline bool all_args_evaluated(const Vec<ASR::expr_t*> &args) {
for (auto &a : args) {
Expand Down
6 changes: 6 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include <libasr/pass/class_constructor.h>
#include <libasr/pass/unused_functions.h>
#include <libasr/pass/inline_function_calls.h>
#include <libasr/pass/dead_code_removal.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>
#include <libasr/codegen/llvm_utils.h>
Expand Down Expand Up @@ -4272,6 +4273,11 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,

pass_replace_do_loops(al, asr);
pass_replace_forall(al, asr);

if( fast ) {
pass_dead_code_removal(al, asr, rl_path);
}

pass_replace_select_case(al, asr);
pass_unused_functions(al, asr);

Expand Down
109 changes: 109 additions & 0 deletions src/libasr/pass/dead_code_removal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include <libasr/asr.h>
#include <libasr/containers.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>
#include <libasr/asr_verify.h>
#include <libasr/pass/dead_code_removal.h>
#include <libasr/pass/pass_utils.h>

#include <vector>
#include <map>
#include <utility>


namespace LFortran {

using ASR::down_cast;
using ASR::is_a;


class DeadCodeRemovalVisitor : public PassUtils::PassVisitor<DeadCodeRemovalVisitor>
{
private:

std::string rl_path;

public:

bool dead_code_removed;

DeadCodeRemovalVisitor(Allocator &al_, const std::string& rl_path_) : PassVisitor(al_, nullptr),
rl_path(rl_path_), dead_code_removed(false)
{
pass_result.reserve(al, 1);
}

void visit_If(const ASR::If_t& x) {
ASR::If_t& xx = const_cast<ASR::If_t&>(x);
transform_stmts(xx.m_body, xx.n_body);
transform_stmts(xx.m_orelse, xx.n_orelse);
ASR::expr_t* m_test_value = ASRUtils::expr_value(x.m_test);
bool m_test_bool;
if( ASRUtils::is_value_constant(m_test_value, m_test_bool) ) {
ASR::stmt_t** selected_part = nullptr;
size_t n_selected_part = 0;
if( m_test_bool ) {
selected_part = x.m_body;
n_selected_part = x.n_body;
} else {
selected_part = x.m_orelse;
n_selected_part = x.n_orelse;
}
for( size_t i = 0; i < n_selected_part; i++ ) {
pass_result.push_back(al, selected_part[i]);
}
dead_code_removed = true;
}
}

void visit_Select(const ASR::Select_t& x) {
ASR::Select_t& xx = const_cast<ASR::Select_t&>(x);
ASR::expr_t* m_test_value = ASRUtils::expr_value(x.m_test);
if( !ASRUtils::is_value_constant(m_test_value) ) {
return ;
}

for( size_t i = 0; i < x.n_body; i++ ) {
ASR::case_stmt_t* case_body = x.m_body[i];
switch (case_body->type) {
case ASR::case_stmtType::CaseStmt: {
ASR::CaseStmt_t* casestmt = ASR::down_cast<ASR::CaseStmt_t>(case_body);
transform_stmts(casestmt->m_body, casestmt->n_body);
xx.m_body[i] = (ASR::case_stmt_t*)(&(casestmt->base));
for( size_t j = 0; j < casestmt->n_test; j++ ) {
if( ASRUtils::is_value_equal(casestmt->m_test[j], x.m_test) ) {
for( size_t k = 0; k < casestmt->n_body; k++ ) {
pass_result.push_back(al, casestmt->m_body[k]);
}
return ;
}
}
break;
}
case ASR::case_stmtType::CaseStmt_Range: {
ASR::CaseStmt_Range_t* casestmt_range = ASR::down_cast<ASR::CaseStmt_Range_t>(case_body);
transform_stmts(casestmt_range->m_body, casestmt_range->n_body);
xx.m_body[i] = (ASR::case_stmt_t*)(&(casestmt_range->base));
if( ASRUtils::is_value_in_range(casestmt_range->m_start, casestmt_range->m_end, x.m_test) ) {
for( size_t k = 0; k < casestmt_range->n_body; k++ ) {
pass_result.push_back(al, casestmt_range->m_body[k]);
}
return ;
}
break;
}
}
}
}

};

void pass_dead_code_removal(Allocator &al, ASR::TranslationUnit_t &unit,
const std::string& rl_path) {
DeadCodeRemovalVisitor v(al, rl_path);
v.visit_TranslationUnit(unit);
LFORTRAN_ASSERT(asr_verify(unit));
}


} // namespace LFortran
12 changes: 12 additions & 0 deletions src/libasr/pass/dead_code_removal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef LIBASR_PASS_DEAD_CODE_REMOVAL_H
#define LIBASR_PASS_DEAD_CODE_REMOVAL_H

#include <libasr/asr.h>

namespace LFortran {

void pass_dead_code_removal(Allocator &al, ASR::TranslationUnit_t &unit, const std::string& rl_path);

} // namespace LFortran

#endif

0 comments on commit daa19a4

Please sign in to comment.