Skip to content

Commit

Permalink
asdl_cpp: add a new visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
certik committed Mar 19, 2022
1 parent 48ac8a7 commit ba0be44
Showing 1 changed file with 160 additions and 1 deletion.
161 changes: 160 additions & 1 deletion grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,154 @@ def visitField(self, field):
self.emit("}", 2)


class ExprStmtDuplicatorVisitor(ASDLVisitor):

def __init__(self, stream, data):
self.duplicate_stmt = []
self.duplicate_expr = []
self.is_stmt = False
self.is_expr = False
self.is_product = False
super(ExprStmtDuplicatorVisitor, self).__init__(stream, data)

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Expression and statement Duplicator class")
self.emit("")
self.emit("class ExprStmtDuplicator {")
self.emit("private:")
self.emit(" Allocator& al;")
self.emit("")
self.emit("public:")
self.emit(" bool success;")
self.emit("")
self.emit(" ExprStmtDuplicator(Allocator& al_) : al(al_), success(false) {}")
self.emit("")
self.duplicate_stmt.append((" ASR::stmt_t* duplicate_stmt(ASR::stmt_t* x) {", 0))
self.duplicate_stmt.append((" if( !x ) {", 1))
self.duplicate_stmt.append((" return nullptr;", 2))
self.duplicate_stmt.append((" }", 1))
self.duplicate_stmt.append(("", 0))
self.duplicate_stmt.append((" switch(x->type) {", 1))

self.duplicate_expr.append((" ASR::expr_t* duplicate_expr(ASR::expr_t* x) {", 0))
self.duplicate_expr.append((" if( !x ) {", 1))
self.duplicate_expr.append((" return nullptr;", 2))
self.duplicate_expr.append((" }", 1))
self.duplicate_expr.append(("", 0))
self.duplicate_expr.append((" switch(x->type) {", 1))

super(ExprStmtDuplicatorVisitor, self).visitModule(mod)
self.duplicate_stmt.append((" default: {", 2))
self.duplicate_stmt.append((' LFORTRAN_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " statement is not supported yet.");', 3))
self.duplicate_stmt.append((" }", 2))
self.duplicate_stmt.append((" }", 1))
self.duplicate_stmt.append(("", 0))
self.duplicate_stmt.append((" return nullptr;", 1))
self.duplicate_stmt.append((" }", 0))

self.duplicate_expr.append((" default: {", 2))
self.duplicate_expr.append((' LFORTRAN_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " expression is not supported yet.");', 3))
self.duplicate_expr.append((" }", 2))
self.duplicate_expr.append((" }", 1))
self.duplicate_expr.append(("", 0))
self.duplicate_expr.append((" return nullptr;", 1))
self.duplicate_expr.append((" }", 0))
for line, level in self.duplicate_stmt:
self.emit(line, level=level)
self.emit("")
for line, level in self.duplicate_expr:
self.emit(line, level=level)
self.emit("")
self.emit("};")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(ExprStmtDuplicatorVisitor, self).visitType(tp, tp.name)

def visitSum(self, sum, *args):
self.is_stmt = args[0] == 'stmt'
self.is_expr = args[0] == 'expr'
if self.is_stmt or self.is_expr:
for tp in sum.types:
self.visit(tp, *args)

def visitProduct(self, prod, name):
pass

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
if name == "FunctionCall" or name == "SubroutineCall":
if self.is_stmt:
self.duplicate_stmt.append((" case ASR::stmtType::%s: {" % name, 2))
self.duplicate_stmt.append((" success = false;", 3))
self.duplicate_stmt.append((" return nullptr;", 3))
self.duplicate_stmt.append((" }", 2))
elif self.is_expr:
self.duplicate_expr.append((" case ASR::exprType::%s: {" % name, 2))
self.duplicate_expr.append((" success = false;", 3))
self.duplicate_expr.append((" return nullptr;", 3))
self.duplicate_expr.append((" }", 2))
return None

self.emit("")
self.emit("ASR::asr_t* duplicate_%s(%s_t* x) {" % (name, name), 1)
self.used = False
arguments = []
for field in fields:
ret_value = self.visitField(field)
for node_arg in ret_value:
arguments.append(node_arg)
if not self.used:
self.emit("return (asr_t*)x;", 2)
else:
node_arg_str = ', '.join(arguments)
self.emit("return make_%s_t(al, x->base.base.loc, %s);" %(name, node_arg_str), 2)
if self.is_stmt:
self.duplicate_stmt.append((" case ASR::stmtType::%s: {" % name, 2))
self.duplicate_stmt.append((" return down_cast<ASR::stmt_t>(duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name, name), 3))
self.duplicate_stmt.append((" }", 2))
elif self.is_expr:
self.duplicate_expr.append((" case ASR::exprType::%s: {" % name, 2))
self.duplicate_expr.append((" return down_cast<ASR::expr_t>(duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name, name), 3))
self.duplicate_expr.append((" }", 2))
self.emit("}", 1)
self.emit("")

def visitField(self, field):
arguments = None
if field.type == "expr" or field.type == "stmt" or field.type == "symbol":
level = 2
if field.seq:
self.used = True
self.emit("Vec<%s_t*> m_%s;" % (field.type, field.name), level)
self.emit("m_%s.reserve(al, x->n_%s);" % (field.name, field.name), level)
self.emit("for (size_t i = 0; i < x->n_%s; i++) {" % field.name, level)
if field.type == "symbol":
self.emit(" m_%s.push_back(al, x->m_%s[i]);" % (field.name, field.name), level)
else:
self.emit(" m_%s.push_back(al, duplicate_%s(x->m_%s[i]));" % (field.name, field.type, field.name), level)
self.emit("}", level)
arguments = ("m_" + field.name + ".p", "x->n_" + field.name)
else:
self.used = True
if field.type == "symbol":
self.emit("%s_t* m_%s = x->m_%s;" % (field.type, field.name, field.name), level)
else:
self.emit("%s_t* m_%s = duplicate_%s(x->m_%s);" % (field.type, field.name, field.type, field.name), level)
arguments = ("m_" + field.name, )
else:
if field.seq:
arguments = ("x->m_" + field.name, "x->n_" + field.name)
else:
arguments = ("x->m_" + field.name, )
return arguments



class PickleVisitorVisitor(ASDLVisitor):

def visitModule(self, mod):
Expand Down Expand Up @@ -1296,13 +1444,24 @@ def main(argv):
if subs["MOD"] == "PYTHON":
subs["MOD"] = "Python::AST"
subs["mod"] = "ast"
is_asr = (mod.name.upper() == "ASR")
fp = open(out_file, "w")
try:
fp.write(HEAD % subs)
for visitor in visitors:
visitor(fp, data).visit(mod)
fp.write("\n\n")
fp.write(FOOT % subs)
if not is_asr:
fp.write(FOOT % subs)
finally:
if not is_asr:
fp.close()

try:
if is_asr:
ExprStmtDuplicatorVisitor(fp, data).visit(mod)
fp.write("\n\n")
fp.write(FOOT % subs)
finally:
fp.close()

Expand Down

0 comments on commit ba0be44

Please sign in to comment.