Skip to content

Commit

Permalink
Sync libasr from LFortran
Browse files Browse the repository at this point in the history
Co-authored-by: Shaikh Ubaid <shaikhubaid769@gmail.com>
  • Loading branch information
czgdp1807 and Shaikh-Ubaid committed Apr 7, 2023
1 parent b35e744 commit 35a8d86
Show file tree
Hide file tree
Showing 32 changed files with 2,129 additions and 431 deletions.
21 changes: 12 additions & 9 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ symbol
symbol external, identifier module_name, identifier* scope_names,
identifier original_name, access access)
| StructType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, bool is_packed,
identifier* members, abi abi, access access, bool is_packed, bool is_abstract,
expr? alignment, symbol? parent)
| EnumType(symbol_table symtab, identifier name, identifier* dependencies,
identifier* members, abi abi, access access, enumtype enum_value_type,
Expand All @@ -107,7 +107,7 @@ symbol
abi abi, access access, presence presence, bool value_attr)
| ClassType(symbol_table symtab, identifier name, abi abi, access access)
| ClassProcedure(symbol_table parent_symtab, identifier name, identifier? self_argument,
identifier proc_name, symbol proc, abi abi)
identifier proc_name, symbol proc, abi abi, bool is_deferred)
| AssociateBlock(symbol_table symtab, identifier name, stmt* body)
| Block(symbol_table symtab, identifier name, stmt* body)

Expand Down Expand Up @@ -161,15 +161,15 @@ stmt
| Assign(int label, identifier variable)
| Assignment(expr target, expr value, stmt? overloaded)
| Associate(expr target, expr value)
| Cycle()
| Cycle(identifier? stmt_name)
-- deallocates if allocated otherwise throws a runtime error
| ExplicitDeallocate(expr* vars)
-- deallocates if allocated otherwise does nothing
| ImplicitDeallocate(symbol* vars)
| ImplicitDeallocate(expr* vars)
| DoConcurrentLoop(do_loop_head head, stmt* body)
| DoLoop(do_loop_head head, stmt* body)
| DoLoop(identifier? name, do_loop_head head, stmt* body)
| ErrorStop(expr? code)
| Exit()
| Exit(identifier? stmt_name)
| ForAllSingle(do_loop_head head, stmt assign_stmt)
-- GoTo points to a GoToTarget with the corresponding target_id within
-- the same procedure. We currently use `int` IDs to link GoTo with
Expand Down Expand Up @@ -201,12 +201,12 @@ stmt
| Assert(expr test, expr? msg)
| SubroutineCall(symbol name, symbol? original_name, call_arg* args, expr? dt)
| Where(expr test, stmt* body, stmt* orelse)
| WhileLoop(expr test, stmt* body)
| WhileLoop(identifier? name, expr test, stmt* body)
| Nullify(symbol* vars)
| Flush(int label, expr unit, expr? err, expr? iomsg, expr? iostat)
| ListAppend(expr a, expr ele)
| AssociateBlockCall(symbol m)
| SelectType(type_stmt* body, stmt* default)
| SelectType(expr selector, type_stmt* body, stmt* default)
| CPtrToPointer(expr cptr, expr ptr, expr? shape)
| BlockCall(int label, symbol m)
| SetInsert(expr a, expr ele)
Expand Down Expand Up @@ -420,7 +420,10 @@ do_loop_head = (expr? v, expr? start, expr? end, expr? increment)

case_stmt = CaseStmt(expr* test, stmt* body) | CaseStmt_Range(expr? start, expr? end, stmt* body)

type_stmt = TypeStmtName(symbol sym, stmt* body) | TypeStmtType(ttype type, stmt* body)
type_stmt
= TypeStmtName(symbol sym, stmt* body)
| ClassStmt(symbol sym, stmt* body)
| TypeStmtType(ttype type, stmt* body)

enumtype = IntegerConsecutiveFromZero | IntegerUnique | IntegerNotUnique | NonInteger

Expand Down
123 changes: 117 additions & 6 deletions src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,111 @@ def visitField(self, field):
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class ASRPassWalkVisitorVisitor(ASDLVisitor):

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Walk Visitor base class")
self.emit("")
self.emit("template <class Struct>")
self.emit("class ASRPassBaseWalkVisitor : public BaseVisitor<Struct>")
self.emit("{")
self.emit("private:")
self.emit(" Struct& self() { return static_cast<Struct&>(*this); }")
self.emit("public:")
self.emit(" SymbolTable* current_scope;")
self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {")
self.emit(" for (size_t i = 0; i < n_body; i++) {", 1)
self.emit(" self().visit_stmt(*m_body[i]);", 1)
self.emit(" }", 1)
self.emit("}", 1)
super(ASRPassWalkVisitorVisitor, self).visitModule(mod)
self.emit("};")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(ASRPassWalkVisitorVisitor, self).visitType(tp, tp.name)

def visitProduct(self, prod, name):
self.make_visitor(name, prod.fields)

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
is_symtab_present = False
is_stmt_present = False
symtab_field_name = ""
for field in fields:
if field.type == "stmt":
is_stmt_present = True
if field.type == "symbol_table":
is_symtab_present = True
symtab_field_name = field.name
if is_stmt_present and is_symtab_present:
break
if is_stmt_present and name not in ("Assignment", "ForAllSingle"):
self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1)
self.used = False

if is_symtab_present:
self.emit("SymbolTable* current_scope_copy = current_scope;", 2)
self.emit("current_scope = x.m_%s;" % symtab_field_name, 2)

for field in fields:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)

if is_symtab_present:
self.emit("current_scope = current_scope_copy;", 2)

self.emit("}", 1)

def visitField(self, field):
if (field.type not in asdl.builtin_types and
field.type not in self.data.simple_types):
level = 2
if field.seq:
if field.type == "stmt":
self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level)
return
self.used = True
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
else:
if field.type != "symbol":
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
self.emit("}", level)
else:
if field.type in products:
self.used = True
if field.opt:
self.emit("if (x.m_%s)" % field.name, 2)
level = 3
if field.opt:
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
else:
self.emit("self().visit_%s(x.m_%s);" % (field.type, field.name), level)
else:
if field.type != "symbol":
self.used = True
if field.opt:
self.emit("if (x.m_%s)" % field.name, 2)
level = 3
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
elif field.type == "symbol_table" and field.name in["symtab",
"global_scope"]:
self.used = True
self.emit("for (auto &a : x.m_%s->get_scope()) {" % field.name, 2)
self.emit( "this->visit_symbol(*a.second);", 3)
self.emit("}", 2)

class CallReplacerOnExpressionsVisitor(ASDLVisitor):

def __init__(self, stream, data):
Expand Down Expand Up @@ -477,10 +582,10 @@ def make_visitor(self, name, fields):
self.emit("}", 1)

def insert_call_replacer_code(self, name, level, index=""):
self.emit(" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self.current_expr_copy_variable_count), level)
self.emit(" current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
self.emit(" self().call_replacer();", level)
self.emit(" current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
self.emit("ASR::expr_t** current_expr_copy_%d = current_expr;" % (self.current_expr_copy_variable_count), level)
self.emit("current_expr = const_cast<ASR::expr_t**>(&(x.m_%s%s));" % (name, index), level)
self.emit("self().call_replacer();", level)
self.emit("current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level)
self.current_expr_copy_variable_count += 1

def visitField(self, field):
Expand All @@ -495,12 +600,14 @@ def visitField(self, field):
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
if field.type == "expr":
self.insert_call_replacer_code(field.name, level, "[i]")
self.insert_call_replacer_code(field.name, level + 1, "[i]")
self.emit("if( x.m_%s[i] )" % (field.name), level)
self.emit(" self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level)
else:
if field.type != "symbol":
if field.type == "expr":
self.insert_call_replacer_code(field.name, level, "[i]")
self.insert_call_replacer_code(field.name, level + 1, "[i]")
self.emit("if( x.m_%s[i] )" % (field.name), level + 1)
self.emit(" self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level)
self.emit("}", level)
else:
Expand All @@ -511,6 +618,7 @@ def visitField(self, field):
level = 3
if field.type == "expr":
self.insert_call_replacer_code(field.name, level)
self.emit("if( x.m_%s )" % (field.name), level)
if field.opt:
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
self.emit("}", 2)
Expand All @@ -524,6 +632,7 @@ def visitField(self, field):
level = 3
if field.type == "expr":
self.insert_call_replacer_code(field.name, level)
self.emit("if( x.m_%s )" % (field.name), level)
self.emit("self().visit_%s(*x.m_%s);" % (field.type, field.name), level)
if field.opt:
self.emit("}", 2)
Expand Down Expand Up @@ -2595,6 +2704,8 @@ def main(argv):

try:
if is_asr:
ASRPassWalkVisitorVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprStmtDuplicatorVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprBaseReplacerVisitor(fp, data).visit(mod)
Expand Down
14 changes: 14 additions & 0 deletions src/libasr/asr_scopes.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,21 @@ struct SymbolTable {
scope.erase(name);
}

// Add a new symbol that did not exist before
void add_symbol(const std::string &name, ASR::symbol_t* symbol) {
LCOMPILERS_ASSERT(scope.find(name) == scope.end())
scope[name] = symbol;
}

// Overwrite an existing symbol
void overwrite_symbol(const std::string &name, ASR::symbol_t* symbol) {
LCOMPILERS_ASSERT(scope.find(name) != scope.end())
scope[name] = symbol;
}

// Use as the last resort, prefer to always either add a new symbol
// or overwrite an existing one, not both
void add_or_overwrite_symbol(const std::string &name, ASR::symbol_t* symbol) {
scope[name] = symbol;
}

Expand Down
Loading

0 comments on commit 35a8d86

Please sign in to comment.