Skip to content

Commit

Permalink
Update asdl_cpp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
certik committed Jun 4, 2022
1 parent 5e9203a commit 51f5689
Showing 1 changed file with 244 additions and 5 deletions.
249 changes: 244 additions & 5 deletions grammar/asdl_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def visitModule(self, mod):
self.emit( "void inc_lindent() {", 1)
self.emit( "indent_level++;", 2)
self.emit( 'indtd += "| ";', 2)
self.emit( "}", 1)
self.emit( "}", 1)
self.emit( "void dec_indent() {", 1)
self.emit( "indent_level--;", 2)
self.emit( "LFORTRAN_ASSERT(indent_level >= 0);", 2)
Expand Down Expand Up @@ -535,9 +535,9 @@ def visitField(self, field, cons, last):
self.emit("self().visit_%s(*x.m_%s[i]);" % (field.type, field.name), level+1)
else:
self.emit("self().visit_%s(x.m_%s[i]);" % (field.type, field.name), level+1)
self.emit( 'dec_indent();', level+1)
self.emit( 'dec_indent();', level+1)
self.emit("}", level)
elif field.opt:
elif field.opt:
self.emit('s.append("\\n" + indtd + "%s" + "%s=");' % (arr, field.name), 2)
if last:
self.emit('last = true;', 2)
Expand All @@ -562,7 +562,7 @@ def visitField(self, field, cons, last):
self.emit('s.append("\\n" + indtd + "%s" + "%s=");' % (arr, field.name), level)
self.emit("for (size_t i=0; i<x.n_%s; i++) {" % field.name, level)
self.emit( "s.append(x.m_%s[i]);" % (field.name), level+1)
self.emit( 'if (i < x.n_%s-1) s.append(" ");' % (field.name), level+1)
self.emit( 'if (i < x.n_%s-1) s.append(" ");' % (field.name), level+1)
self.emit("}", level)
else:
if field.opt:
Expand All @@ -587,7 +587,7 @@ def visitField(self, field, cons, last):
self.emit( 'last = i == x.n_%s-1;' % field.name, level+1)
self.emit( 'attached = false;', level+1)
self.emit( "self().visit_%s(*x.m_%s[i]);" % (mod_name, field.name), level+1)
self.emit( 'dec_indent();', level+1)
self.emit( 'dec_indent();', level+1)
self.emit("}", level)
elif field.type == "symbol_table":
assert not field.opt
Expand Down Expand Up @@ -815,6 +815,103 @@ def visitField(self, field):
arguments = ("x->m_" + field.name, )
return arguments

class ExprBaseReplacerVisitor(ASDLVisitor):

def __init__(self, stream, data):
self.replace_expr = []
self.is_expr = False
self.is_product = False
super(ExprBaseReplacerVisitor, self).__init__(stream, data)

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Expression Replacer Base class")
self.emit("")
self.emit("template <class Derived>")
self.emit("class BaseExprReplacer {")
self.emit("public:")
self.emit(" Derived& self() { return static_cast<Derived&>(*this); }")
self.emit("")
self.emit(" ASR::expr_t** current_expr;")
self.emit(" ASR::expr_t** current_expr_copy;")
self.emit("")
self.emit(" BaseExprReplacer() : current_expr(nullptr) {}")
self.emit("")

self.replace_expr.append((" void replace_expr(ASR::expr_t* x) {", 0))
self.replace_expr.append((" if( !x ) {", 1))
self.replace_expr.append((" return ;", 2))
self.replace_expr.append((" }", 1))
self.replace_expr.append(("", 0))
self.replace_expr.append((" switch(x->type) {", 1))

super(ExprBaseReplacerVisitor, self).visitModule(mod)

self.replace_expr.append((" default: {", 2))
self.replace_expr.append((' LFORTRAN_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " expression is not supported yet.");', 3))
self.replace_expr.append((" }", 2))
self.replace_expr.append((" }", 1))
self.replace_expr.append(("", 0))
self.replace_expr.append((" }", 0))
for line, level in self.replace_expr:
self.emit(line, level=level)
self.emit("")
self.emit("};")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(ExprBaseReplacerVisitor, self).visitType(tp, tp.name)

def visitSum(self, sum, *args):
self.is_expr = args[0] == 'expr'
if self.is_expr:
for tp in sum.types:
self.visit(tp, *args)

def visitProduct(self, prod, name):
pass

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
self.emit("")
self.emit("void replace_%s(%s_t* x) {" % (name, name), 1)
self.used = False
for field in fields:
self.visitField(field)
if not self.used:
self.emit("if (x) { }", 2)

if self.is_expr:
self.replace_expr.append((" case ASR::exprType::%s: {" % name, 2))
self.replace_expr.append((" self().replace_%s(down_cast<ASR::%s_t>(x));" % (name, name), 3))
self.replace_expr.append((" break;", 3))
self.replace_expr.append((" }", 2))
self.emit("}", 1)
self.emit("")

def visitField(self, field):
arguments = None
if field.type == "expr" or field.type == "symbol" or field.type == "call_arg":
level = 2
if field.seq:
self.used = True
self.emit("for (size_t i = 0; i < x->n_%s; i++) {" % field.name, level)
if field.type == "call_arg":
self.emit(" current_expr_copy = current_expr;", level)
self.emit(" current_expr = &(x->m_%s[i].m_value);" % (field.name), level)
self.emit(" self().replace_expr(x->m_%s[i].m_value);"%(field.name), level)
self.emit(" current_expr = current_expr_copy;", level)
self.emit("}", level)
else:
if field.type != "symbol":
self.used = True
self.emit("current_expr_copy = current_expr;", level)
self.emit("current_expr = &(x->m_%s);" % (field.name), level)
self.emit("self().replace_%s(x->m_%s);" % (field.type, field.name), level)
self.emit("current_expr = current_expr_copy;", level)


class PickleVisitorVisitor(ASDLVisitor):
Expand Down Expand Up @@ -1576,6 +1673,142 @@ def visitConstructor(self, cons, _):
self.emit( 'return %s::make_%s_t(%s);' % (subs["MOD"], name, ", ".join(args)), 2)
self.emit("}", 1)

class ExprTypeVisitor(ASDLVisitor):

def __init__(self, stream, data):
self.replace_expr = []
self.is_expr = False
self.is_product = False
super(ExprTypeVisitor, self).__init__(stream, data)

def emit(self, line, level=0, new_line=True):
indent = " "*level
self.stream.write(indent + line)
if new_line:
self.stream.write("\n")

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Expression Type (`expr_type`) visitor")
self.emit("""\
static inline ASR::ttype_t* expr_type0(const ASR::expr_t *f)
{
LFORTRAN_ASSERT(f != nullptr);
switch (f->type) {""")

super(ExprTypeVisitor, self).visitModule(mod)

self.emit(""" default : throw LFortranException("Not implemented");
}
}
""")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(ExprTypeVisitor, self).visitType(tp, tp.name)

def visitSum(self, sum, *args):
self.is_expr = args[0] == 'expr'
if self.is_expr:
for tp in sum.types:
self.visit(tp, *args)

def visitProduct(self, prod, name):
pass

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
if name == "Var":
self.emit("""case ASR::exprType::%s: {
ASR::symbol_t *s = ((ASR::%s_t*)f)->m_v;
if (s->type == ASR::symbolType::ExternalSymbol) {
ASR::ExternalSymbol_t *e = ASR::down_cast<ASR::ExternalSymbol_t>(s);
LFORTRAN_ASSERT(!ASR::is_a<ASR::ExternalSymbol_t>(*e->m_external));
s = e->m_external;
}
return ASR::down_cast<ASR::Variable_t>(s)->m_type;
}""" \
% (name, name), 2, new_line=False)
else:
self.emit("case ASR::exprType::%s: { return ((ASR::%s_t*)f)->m_type; }"\
% (name, name), 2, new_line=False)
self.emit("")

def visitField(self, field):
pass

class ExprValueVisitor(ASDLVisitor):

def __init__(self, stream, data):
self.replace_expr = []
self.is_expr = False
self.is_product = False
super(ExprValueVisitor, self).__init__(stream, data)

def emit(self, line, level=0, new_line=True):
indent = " "*level
self.stream.write(indent + line)
if new_line:
self.stream.write("\n")

def visitModule(self, mod):
self.emit("/" + "*"*78 + "/")
self.emit("// Expression Value (`expr_value`) visitor")
self.emit("""\
static inline ASR::expr_t* expr_value0(ASR::expr_t *f)
{
LFORTRAN_ASSERT(f != nullptr);
switch (f->type) {""")

super(ExprValueVisitor, self).visitModule(mod)

self.emit(""" default : throw LFortranException("Not implemented");
}
}
""")

def visitType(self, tp):
if not (isinstance(tp.value, asdl.Sum) and
is_simple_sum(tp.value)):
super(ExprValueVisitor, self).visitType(tp, tp.name)

def visitSum(self, sum, *args):
self.is_expr = args[0] == 'expr'
if self.is_expr:
for tp in sum.types:
self.visit(tp, *args)

def visitProduct(self, prod, name):
pass

def visitConstructor(self, cons, _):
self.make_visitor(cons.name, cons.fields)

def make_visitor(self, name, fields):
if name == "Var":
self.emit("""case ASR::exprType::%s: {
ASR::symbol_t *s = ((ASR::%s_t*)f)->m_v;
if (s->type == ASR::symbolType::ExternalSymbol) {
ASR::ExternalSymbol_t *e = ASR::down_cast<ASR::ExternalSymbol_t>(s);
LFORTRAN_ASSERT(!ASR::is_a<ASR::ExternalSymbol_t>(*e->m_external));
s = e->m_external;
}
return ASR::down_cast<ASR::Variable_t>(s)->m_value;
}""" \
% (name, name), 2, new_line=False)
elif name.endswith("Constant") or name == "IntegerBOZ":
self.emit("case ASR::exprType::%s: { return f; }"\
% (name), 2, new_line=False)
else:
self.emit("case ASR::exprType::%s: { return ((ASR::%s_t*)f)->m_value; }"\
% (name, name), 2, new_line=False)
self.emit("")

def visitField(self, field):
pass

class ASDLData(object):

Expand Down Expand Up @@ -1732,6 +1965,12 @@ def main(argv):
if is_asr:
ExprStmtDuplicatorVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprBaseReplacerVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprTypeVisitor(fp, data).visit(mod)
fp.write("\n\n")
ExprValueVisitor(fp, data).visit(mod)
fp.write("\n\n")
fp.write(FOOT % subs)
finally:
fp.close()
Expand Down

0 comments on commit 51f5689

Please sign in to comment.