Skip to content

Commit

Permalink
Revamp arr_slice pass by using latest replacer APIs in ``asdl_cpp…
Browse files Browse the repository at this point in the history
….py`` (lcompilers#1530)
  • Loading branch information
czgdp1807 committed Feb 17, 2023
1 parent 5e9e730 commit 20123ad
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 193 deletions.
31 changes: 30 additions & 1 deletion src/libasr/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,14 @@ def visitModule(self, mod):
self.emit(" Struct& self() { return static_cast<Struct&>(*this); }")
self.emit("public:")
self.emit(" ASR::expr_t** current_expr;")
self.emit(" SymbolTable* current_scope;")
self.emit("")
self.emit(" void call_replacer() {}")
self.emit(" void transform_stmts(ASR::stmt_t **&m_body, size_t &n_body) {")
self.emit(" for (size_t i = 0; i < n_body; i++) {", 1)
self.emit(" self().visit_stmt(*m_body[i]);", 1)
self.emit(" }", 1)
self.emit(" }")
super(CallReplacerOnExpressionsVisitor, self).visitModule(mod)
self.emit("};")

Expand All @@ -440,14 +446,34 @@ def visitConstructor(self, cons, _):

def make_visitor(self, name, fields):
self.emit("void visit_%s(const %s_t &x) {" % (name, name), 1)
is_symtab_present = False
is_stmt_present = False
symtab_field_name = ""
for field in fields:
if field.type == "stmt":
is_stmt_present = True
if field.type == "symbol_table":
is_symtab_present = True
symtab_field_name = field.name
if is_stmt_present and is_symtab_present:
break
if is_stmt_present and name not in ("Assignment", "ForAllSingle"):
self.emit(" %s_t& xx = const_cast<%s_t&>(x);" % (name, name), 1)
self.used = False
have_body = False

if is_symtab_present:
self.emit("SymbolTable* current_scope_copy = current_scope;", 2)
self.emit("current_scope = x.m_%s;" % symtab_field_name, 2)

for field in fields:
self.visitField(field)
if not self.used:
# Note: a better solution would be to change `&x` to `& /* x */`
# above, but we would need to change emit to return a string.
self.emit("if ((bool&)x) { } // Suppress unused warning", 2)

if is_symtab_present:
self.emit("current_scope = current_scope_copy;", 2)
self.emit("}", 1)

def insert_call_replacer_code(self, name, level, index=""):
Expand All @@ -462,6 +488,9 @@ def visitField(self, field):
field.type not in self.data.simple_types):
level = 2
if field.seq:
if field.type == "stmt":
self.emit("self().transform_stmts(xx.m_%s, xx.n_%s);" % (field.name, field.name), level)
return
self.used = True
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
if field.type in products:
Expand Down
Loading

0 comments on commit 20123ad

Please sign in to comment.