Skip to content

Commit

Permalink
ASR: update from LFortran (lcompilers#874)
Browse files Browse the repository at this point in the history
Co-authored-by: Gagandeep Singh <gdp.1807@gmail.com>
  • Loading branch information
certik and czgdp1807 authored Aug 3, 2022
1 parent 25416f4 commit 7210694
Show file tree
Hide file tree
Showing 57 changed files with 2,447 additions and 491 deletions.
131 changes: 130 additions & 1 deletion grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,129 @@ def visitField(self, field):
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class CallReplacerOnExpressionsVisitor(ASDLVisitor):

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Walk Visitor base class")
self.emit("")
self.emit("template <class Derived>")
self.emit("class CallReplacerOnExpressionsVisitor : public BaseVisitor<Derived>")
self.emit("{")
self.emit("private:")
self.emit(" Derived& self() { return static_cast<Derived&>(*this); }")
self.emit("public:")
self.emit(" ASR::expr_t** current_expr;")
self.emit(" ASR::expr_t** current_expr_copy;")
self.emit("")
self.emit(" void call_replacer() {}")
super(CallReplacerOnExpressionsVisitor, self).visitModule(mod)
self.emit("};")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(CallReplacerOnExpressionsVisitor, self).visitType(tp, tp.name)

def visitProduct(self, prod, name):
self.make_visitor(name, prod.fields)

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
self.used = False
have_body = False
for field in fields:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)
self.emit("}", 1)

def insert_call_replacer_code(self, name, level, index=""):
self.emit(" current_expr_copy = current_expr;", level)
self.emit(" current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
self.emit(" self().call_replacer();", level)
self.emit(" current_expr = current_expr_copy;", level)

def visitField(self, field):
if (field.type not in asdl.builtin_types and
field.type not in self.data.simple_types):
level = 2
if field.seq:
self.used = True
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
if field.type == "expr":
self.insert_call_replacer_code(field.name, level, "[i]")
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
else:
if field.type != "symbol":
if field.type == "expr":
self.insert_call_replacer_code(field.name, level, "[i]")
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
self.emit("}", level)
else:
if field.type in products:
self.used = True
if field.opt:
self.emit("if (x.m_%s) {" % field.name, 2)
level = 3
if field.type == "expr":
self.insert_call_replacer_code(field.name, level)
if field.opt:
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
self.emit("}", 2)
else:
self.emit("self().visit_%s(x.m_%s);" % (field.type, field.name), level)
else:
if field.type != "symbol":
self.used = True
if field.opt:
self.emit("if (x.m_%s) {" % field.name, 2)
level = 3
if field.type == "expr":
self.insert_call_replacer_code(field.name, level)
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
if field.opt:
self.emit("}", 2)
elif field.type == "symbol_table" and field.name in["symtab",
"global_scope"]:
self.used = True
self.emit("for (auto &a : x.m_%s->get_scope()) {" % field.name, 2)
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class StatementsFirstWalkVisitorVisitor(ASTWalkVisitorVisitor, ASDLVisitor):

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Statements First Visitor base class")
self.emit("")
self.emit("template <class Derived>")
self.emit("class StatementsFirstBaseWalkVisitor : public BaseVisitor<Derived>")
self.emit("{")
self.emit("private:")
self.emit(" Derived& self() { return static_cast<Derived&>(*this); }")
self.emit("public:")
super(ASTWalkVisitorVisitor, self).visitModule(mod)
self.emit("};")

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
self.used = False
have_body = False
for field in fields[::-1]:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)
self.emit("}", 1)

# This class generates a visitor that prints the tree structure of AST/ASR
class TreeVisitorVisitor(ASDLVisitor):

Expand Down Expand Up @@ -1822,6 +1945,9 @@ def make_visitor(self, name, fields):
return ASR::down_cast<ASR::Variable_t>(s)->m_type;
}""" \
% (name, name), 2, new_line=False)
elif name == "OverloadedBinOp":
self.emit("case ASR::exprType::%s: { return expr_type0(((ASR::%s_t*)f)->m_overloaded); }"\
% (name, name), 2, new_line=False)
else:
self.emit("case ASR::exprType::%s: { return ((ASR::%s_t*)f)->m_type; }"\
% (name, name), 2, new_line=False)
Expand Down Expand Up @@ -2010,7 +2136,8 @@ def add_masks(fields, node):
visitors = [ASTNodeVisitor0, ASTNodeVisitor1, ASTNodeVisitor,
ASTVisitorVisitor1, ASTVisitorVisitor1b, ASTVisitorVisitor2,
ASTWalkVisitorVisitor, TreeVisitorVisitor, PickleVisitorVisitor,
SerializationVisitorVisitor, DeserializationVisitorVisitor]
StatementsFirstWalkVisitorVisitor, SerializationVisitorVisitor,
DeserializationVisitorVisitor]


def main(argv):
Expand Down Expand Up @@ -2059,6 +2186,8 @@ def main(argv):
fp.write("\n\n")
StmtBaseReplacerVisitor(fp, data).visit(mod)
fp.write("\n\n")
CallReplacerOnExpressionsVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprTypeVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprValueVisitor(fp, data).visit(mod)
Expand Down
6 changes: 4 additions & 2 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ expr
| ListLen(expr arg, ttype type, expr? value)
| ListConcat(expr left, expr right, ttype type, expr? value)

| ArrayConstant(expr* args, ttype type)

| SetConstant(expr* elements, ttype type)
| SetLen(expr arg, ttype type, expr? value)

Expand All @@ -264,7 +262,10 @@ expr

| DictConstant(expr* keys, expr* values, ttype type)
| DictLen(expr arg, ttype type, expr? value)

| Var(symbol v)

| ArrayConstant(expr* args, ttype type)
| ArrayItem(expr v, array_index* args, ttype type, expr? value)
| ArraySection(expr v, array_index* args, ttype type, expr? value)
| ArraySize(expr v, expr? dim, ttype type, expr? value)
Expand All @@ -274,6 +275,7 @@ expr
| ArrayMatMul(expr matrix_a, expr matrix_b, ttype type, expr? value)
| ArrayPack(expr array, expr mask, expr? vector, ttype type, expr? value)
| ArrayReshape(expr array, expr shape, ttype type, expr? value)

| BitCast(expr source, expr mold, expr? size, ttype type, expr? value)
| DerivedRef(expr v, symbol m, ttype type, expr? value)
| OverloadedCompare(expr left, cmpop op, expr right, ttype type, expr? value, expr overloaded)
Expand Down
5 changes: 3 additions & 2 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ set(SRC
codegen/asr_to_wasm.cpp
codegen/wasm_to_wat.cpp
codegen/wasm_utils.cpp

pass/param_to_const.cpp
pass/do_loops.cpp
pass/for_all.cpp
Expand All @@ -33,6 +33,7 @@ set(SRC
pass/implied_do_loops.cpp
pass/array_op.cpp
pass/class_constructor.cpp
pass/arr_dims_propagate.cpp
pass/arr_slice.cpp
pass/print_arr.cpp
pass/pass_utils.cpp
Expand All @@ -45,6 +46,7 @@ set(SRC
pass/inline_function_calls.cpp
pass/loop_unroll.cpp
pass/dead_code_removal.cpp
pass/update_array_dim_intrinsic_calls.cpp

asr_verify.cpp
asr_utils.cpp
Expand Down Expand Up @@ -76,7 +78,6 @@ if (WITH_LLVM)
COMPILE_FLAGS -Wno-deprecated-declarations)
endif()
endif()

add_library(asr ${SRC})
target_include_directories(asr BEFORE PUBLIC ${libasr_SOURCE_DIR}/..)
target_include_directories(asr BEFORE PUBLIC ${libasr_BINARY_DIR}/..)
Expand Down
23 changes: 20 additions & 3 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,15 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
if( a_name == nullptr ) {
err("Unable to resolve matched function for operator overloading, " + matched_func_name, loc);
}
ASR::ttype_t *return_type = nullptr;
if( func->m_elemental && func->n_args == 1 && ASRUtils::is_array(ASRUtils::expr_type(a_args[0].m_value)) ) {
return_type = ASRUtils::duplicate_type(al, ASRUtils::expr_type(a_args[0].m_value));
} else {
return_type = ASRUtils::expr_type(func->m_return_var);
}
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
ASRUtils::expr_type(func->m_return_var),
return_type,
nullptr, nullptr);
}
}
Expand Down Expand Up @@ -513,9 +519,15 @@ bool use_overloaded(ASR::expr_t* left, ASR::expr_t* right,
if( a_name == nullptr ) {
err("Unable to resolve matched function for operator overloading, " + matched_func_name, loc);
}
ASR::ttype_t *return_type = nullptr;
if( func->m_elemental && func->n_args == 1 && ASRUtils::is_array(ASRUtils::expr_type(a_args[0].m_value)) ) {
return_type = ASRUtils::duplicate_type(al, ASRUtils::expr_type(a_args[0].m_value));
} else {
return_type = ASRUtils::expr_type(func->m_return_var);
}
asr = ASR::make_FunctionCall_t(al, loc, a_name, sym,
a_args.p, 2,
ASRUtils::expr_type(func->m_return_var),
return_type,
nullptr, nullptr);
}
}
Expand Down Expand Up @@ -810,7 +822,12 @@ ASR::asr_t* symbol_resolve_external_generic_procedure_without_eval(
bool is_subroutine = ASR::is_a<ASR::Subroutine_t>(*final_sym);
ASR::ttype_t *return_type = nullptr;
if( ASR::is_a<ASR::Function_t>(*final_sym) ) {
return_type = LFortran::ASRUtils::EXPR2VAR(ASR::down_cast<ASR::Function_t>(final_sym)->m_return_var)->m_type;
ASR::Function_t* func = ASR::down_cast<ASR::Function_t>(final_sym);
if( func->m_elemental && func->n_args == 1 && ASRUtils::is_array(ASRUtils::expr_type(args[0].m_value)) ) {
return_type = ASRUtils::duplicate_type(al, ASRUtils::expr_type(args[0].m_value));
} else {
return_type = LFortran::ASRUtils::EXPR2VAR(func->m_return_var)->m_type;
}
}
// Create ExternalSymbol for the final subroutine:
// We mangle the new ExternalSymbol's local name as:
Expand Down
40 changes: 27 additions & 13 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,20 @@ ASR::asr_t* symbol_resolve_external_generic_procedure_without_eval(
SymbolTable* current_scope, Allocator& al,
const std::function<void (const std::string &, const Location &)> err);

static inline bool is_dimension_empty(ASR::dimension_t& dim) {
return ((dim.m_length == nullptr) ||
(dim.m_start == nullptr));
}

static inline bool is_dimension_empty(ASR::dimension_t* dims, size_t n) {
for( size_t i = 0; i < n; i++ ) {
if( is_dimension_empty(dims[i]) ) {
return true;
}
}
return false;
}

class ReplaceArgVisitor: public ASR::BaseExprReplacer<ReplaceArgVisitor> {

private:
Expand Down Expand Up @@ -1506,39 +1520,39 @@ static inline ASR::expr_t* compute_end_from_start_length(Allocator& al, ASR::exp
ASRUtils::extract_value(start_value, start_int);
ASRUtils::extract_value(length_value, length_int);
end_value = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, start->base.loc,
length_int + start_int - 1,
ASRUtils::expr_type(start)));
length_int + start_int - 1,
ASRUtils::expr_type(start)));
}
ASR::expr_t* diff = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, length->base.loc, length,
ASR::binopType::Add, start, ASRUtils::expr_type(length),
nullptr));
ASR::binopType::Add, start, ASRUtils::expr_type(length),
nullptr));
ASR::expr_t *constant_one = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(
al, diff->base.loc, 1, ASRUtils::expr_type(diff)));
return ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, length->base.loc, diff,
ASR::binopType::Sub, constant_one, ASRUtils::expr_type(length),
end_value));
ASR::binopType::Sub, constant_one, ASRUtils::expr_type(length),
end_value));
}

static inline ASR::expr_t* compute_length_from_start_end(Allocator& al, ASR::expr_t* start, ASR::expr_t* end) {
ASR::expr_t* start_value = ASRUtils::expr_value(start);
ASR::expr_t* end_value = ASRUtils::expr_value(end);
ASR::expr_t* length_value = nullptr;
if( start_value && end_value ) {
int64_t start_int, end_int;
int64_t start_int = -1, end_int = -1;
ASRUtils::extract_value(start_value, start_int);
ASRUtils::extract_value(end_value, end_int);
length_value = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, start->base.loc,
end_int - start_int + 1,
ASRUtils::expr_type(start)));
end_int - start_int + 1,
ASRUtils::expr_type(start)));
}
ASR::expr_t* diff = ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, end->base.loc, end,
ASR::binopType::Sub, start, ASRUtils::expr_type(end),
nullptr));
ASR::binopType::Sub, start, ASRUtils::expr_type(end),
nullptr));
ASR::expr_t *constant_one = ASR::down_cast<ASR::expr_t>(ASR::make_IntegerConstant_t(
al, diff->base.loc, 1, ASRUtils::expr_type(diff)));
return ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, end->base.loc, diff,
ASR::binopType::Add, constant_one, ASRUtils::expr_type(end),
length_value));
ASR::binopType::Add, constant_one, ASRUtils::expr_type(end),
length_value));
}

} // namespace ASRUtils
Expand Down
5 changes: 2 additions & 3 deletions src/libasr/codegen/asr_to_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class ASRToCPPVisitor : public BaseCCPPVisitor<ASRToCPPVisitor>
v.m_intent != ASRUtils::intent_out, true);
}
} else {
sub = format_type(dims, "int *", v.m_name, use_ref, dummy);
sub = format_type(dims, type_name, v.m_name, use_ref, dummy);
}
} else {
diag.codegen_error_label("Type number '"
Expand All @@ -240,8 +240,7 @@ class ASRToCPPVisitor : public BaseCCPPVisitor<ASRToCPPVisitor>
ASR::Integer_t *t = ASR::down_cast<ASR::Integer_t>(v.m_type);
size_t size;
dims = convert_dims(t->n_dims, t->m_dims, size);
std::string type_name = "int";
if (t->m_kind == 8) type_name = "long long";
std::string type_name = "int" + std::to_string(t->m_kind * 8) + "_t";
if( is_array ) {
if( use_templates_for_arrays ) {
sub += generate_templates_for_arrays(std::string(v.m_name));
Expand Down
Loading

0 comments on commit 7210694

Please sign in to comment.