From 2b2802a3fe102a68db2536b8eb61a91b2343a243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?= Date: Fri, 1 Mar 2019 13:27:54 -0700 Subject: [PATCH] Add ASG module --- build0.sh | 2 + grammar/ASG.asdl | 103 +++++++++++++++++++++++++++++++++++++ lfortran/asg/__init__.py | 0 lfortran/asg/asg_check.py | 38 ++++++++++++++ lfortran/asg/asg_to_ast.py | 90 ++++++++++++++++++++++++++++++++ lfortran/asg/builder.py | 78 ++++++++++++++++++++++++++++ 6 files changed, 311 insertions(+) create mode 100644 grammar/ASG.asdl create mode 100644 lfortran/asg/__init__.py create mode 100644 lfortran/asg/asg_check.py create mode 100644 lfortran/asg/asg_to_ast.py create mode 100644 lfortran/asg/builder.py diff --git a/build0.sh b/build0.sh index 7640c94c2e..aaf8c62ad3 100755 --- a/build0.sh +++ b/build0.sh @@ -5,6 +5,8 @@ set -x # Generate a Fortran AST from Fortran.asdl python grammar/asdl_py.py +# Generate a Fortran ASG from ASG.asdl +python grammar/asdl_py.py grammar/ASG.asdl lfortran/asg/asg.py ..ast.utils # Generate a parse tree from fortran.g4 antlr4="java org.antlr.v4.Tool" diff --git a/grammar/ASG.asdl b/grammar/ASG.asdl new file mode 100644 index 0000000000..3f2a3c8c17 --- /dev/null +++ b/grammar/ASG.asdl @@ -0,0 +1,103 @@ +-- Abstract Semantic Graph (ASG) definition + +-- ASDL's builtin types are: +-- * identifier +-- * int (signed integers of infinite precision) +-- * string +-- We extend these by: +-- * object (any Python object) +-- * constant +-- * symbol_table (scoped Symbol Table implementation), for now we use object +-- +-- Note: `symbol_table` contains `mod`, `sub`, `fn`, `var`. + +module ASG { + +-- FIXME: these functions might be implemented by hand in the code. `sub` and +-- `fn` is used once, and it should be changed to a pointer to +-- Function/Subroutine. `var` is used a lot, but again it should be a pointer to +-- Variable, in the symbol table. + +prog + = Program(identifier name, object symtab) + +mod + = Module(identifier name, object symtab) + +sub + = Subroutine() + attributes (identifier name, expr* args, stmt* body, tbind? bind, + object symtab) + +fn + = Function(expr return_var) + attributes (identifier name, expr* args, stmt* body, tbind? bind, + object symtab) + +stmt + = Assignment(expr target, expr value) + | SubroutineCall(sub name, expr* args) + | BuiltinCall(identifier name, expr* args) + | If(expr test, stmt* body, stmt* orelse) + | Where(expr test, stmt* body, stmt* orelse) + | Stop(int? code) + | ErrorStop() + | DoLoop(do_loop_head? head, stmt* body) + | Select(expr test, case_stmt* body, case_default? default) + | Cycle() + | Exit() + | WhileLoop(expr test, stmt* body) + | Print(string? fmt, expr* values) + +expr + = BoolOp(expr left, boolop op, expr right) + | BinOp(expr left, operator op, expr right) + | UnaryOp(unaryop op, expr operand) + | Compare(expr left, cmpop op, expr right) + | FuncCall(fn func, expr* args, keyword* keywords) + | Array(identifier name, array_index* args) + | ArrayInitializer(expr* args) + | Num(object n) + | Str(string s) + | Variable(identifier name, string? intent, int? dummy, object? scope) + | Constant(constant value) + attributes (ttype type) + +ttype + = Integer(int kind) + | Real(int kind) + | Complex(int kind) + | Character(int kind) + | Logical(int kind) + | Derived() + attributes (dimension* dims) + +boolop = And | Or + +operator = Add | Sub | Mul | Div | Pow + +unaryop = Invert | Not | UAdd | USub + +cmpop = Eq | NotEq | Lt | LtE | Gt | GtE + +dimension = (expr? start, expr? end) + +attribute = Attribute(identifier name, attribute_arg *args) + +attribute_arg = (identifier arg) + +arg = (identifier arg) + +keyword = (identifier? arg, expr value) + +tbind = Bind(keyword* args) + +array_index = ArrayIndex(expr? left, expr? right, expr? step) + +do_loop_head = (expr v, expr start, expr end, expr? increment) + +case_stmt = (expr test, stmt* body) + +case_default = (stmt* body) + +} diff --git a/lfortran/asg/__init__.py b/lfortran/asg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lfortran/asg/asg_check.py b/lfortran/asg/asg_check.py new file mode 100644 index 0000000000..6dd5d0b1d8 --- /dev/null +++ b/lfortran/asg/asg_check.py @@ -0,0 +1,38 @@ +""" +# ASG Check + +This goes over the whole ASG and checks that all requirements are met: + +* It is a valid Fortran code +* All additional internal consistency requirements are satisfied + +This is not meant to report nice user errors, this is only meant to be run in +Debug mode to ensure that LFortran always constructs ASG in the correct form. + +If one *knows* (by checking in Debug mode) that a given algorithm constructs +ASG in the correct form, then one can construct ASG directly using the +classes in the `asg.asg` module. Otherwise one should use the `asg.builder` +module, which will always construct ASG in the correct form, or report a nice +error (that can then be forwarded to the user by LFortran) even in both Debug +and Release modes. The `asg.builder` is built to be robust and handle any +(valid or invalid) input. + +The semantic phase then traverses the AST and uses `asg.builder` to construct +ASG. Thus the `asg.builder` does most of the semantic checks for the semantic +analyzer (which only forwards the errors to the user), thus greatly simplifying +the semantic part of the compiler. + +The hard work of doing semantic checks is encoded in the ASG module, which +does not depend on the rest of LFortran and can be used, verified and +improved independently. +""" + +# TODO: Make this a visitor + +def check_function(f): + for arg in f.args: + assert arg.name in f.symtab.symbols + assert arg.dummy == True + assert f.return_var.name in f.symtab.symbols + assert f.return_var.dummy == True + assert f.return_var.intent is None \ No newline at end of file diff --git a/lfortran/asg/asg_to_ast.py b/lfortran/asg/asg_to_ast.py new file mode 100644 index 0000000000..6d69e2b2ea --- /dev/null +++ b/lfortran/asg/asg_to_ast.py @@ -0,0 +1,90 @@ +from ..ast import ast +from . import asg + +class ASG2ASTVisitor(asg.ASTVisitor): + + def visit_sequence(self, seq): + r = [] + if seq is not None: + for node in seq: + r.append(self.visit(node)) + return r + + def visit_Module(self, node): + decl = [] + contains = [] + for s in node.symtab.symbols: + sym = node.symtab.symbols[s] + if isinstance(sym, asg.Function): + if sym.body: + contains.append(self.visit(sym)) + else: + decl.append( + ast.Interface2(procs=[self.visit(sym)]) + ) + else: + raise NotImplementedError() + return ast.Module(name=node.name, decl=decl, contains=contains) + + def visit_Assignment(self, node): + target = self.visit(node.target) + value = self.visit(node.value) + return ast.Assignment(target, value) + + def visit_BinOp(self, node): + left = self.visit(node.left) + right = self.visit(node.right) + if isinstance(node.op, asg.Add): + op = ast.Add() + elif isinstance(node.op, asg.Mul): + op = ast.Mul() + else: + raise NotImplementedError() + return ast.BinOp(left, op, right) + + def visit_Variable(self, node): + return ast.Name(id=node.name) + + def visit_Num(self, node): + return ast.Num(n=node.n) + + def visit_Integer(self, node): + if node.kind == 4: + return "integer" + else: + return "integer(kind=%d)" % node.kind + + def visit_Function(self, node): + body = self.visit_sequence(node.body) + args = [] + decl = [] + for arg in node.args: + args.append(ast.arg(arg=arg.name)) + stype = self.visit(arg.type) + attrs = [] + if arg.intent: + attrs = [ + ast.Attribute(name="intent", + args=[ast.attribute_arg(arg=arg.intent)]), + ] + decl.append(ast.Declaration(vars=[ + ast.decl(sym=arg.name, sym_type=stype, + attrs=attrs)])) + for s in node.symtab.symbols: + sym = node.symtab.symbols[s] + if sym.dummy: + continue + stype = self.visit(sym.type) + decl.append(ast.Declaration(vars=[ + ast.decl(sym=sym.name, sym_type=stype)])) + return_type = self.visit(node.return_var.type) + return_var = self.visit(node.return_var) + return ast.Function( + name=node.name, args=args, return_type=return_type, + return_var=return_var, + decl=decl, body=body) + + +def asg_to_ast(a): + v = ASG2ASTVisitor() + return v.visit(a) \ No newline at end of file diff --git a/lfortran/asg/builder.py b/lfortran/asg/builder.py new file mode 100644 index 0000000000..54671e3f23 --- /dev/null +++ b/lfortran/asg/builder.py @@ -0,0 +1,78 @@ +""" +# ASG Builder + +Using the ASG builder has the following advantages over constructing the ASG +directly: + +* The ASG is constructed correctly, or a nice error is given +* Is easier to use, for example it handles the scoped symbol table + automatically + +""" + +from . import asg +from .asg_check import check_function +from ..semantic.analyze import Scope + +# Private: + +def _add_symbol(scope, v): + scope.symbols[v.name] = v + v._scope = scope + +def _add_var(scope, v, dummy=False): + v.scope = scope + v.dummy = dummy + _add_symbol(scope, v) + +# Public API: + +def make_type_integer(kind=None): + if not kind: + kind = 4 + return asg.Integer(kind=kind) + + +class FunctionBuilder(): + + def __init__(self, mod, name, args=[], return_var=None, body=[]): + assert isinstance(mod, asg.Module) + scope = mod.symtab + self._name = name + self._args = args.copy() + self._body = body.copy() + self._return_var = return_var + self._parent_scope = scope + self._function_scope = Scope(self._parent_scope) + for arg in args: + _add_var(self._function_scope, arg, dummy=True) + if return_var: + _add_var(self._function_scope, return_var, dummy=True) + + def make_var(self, name, type): + v = asg.Variable(name=name, dummy=False, type=type) + _add_var(self._function_scope, v, dummy=False) + return v + + def add_statements(self, statements): + assert isinstance(statements, list) + self._body += statements + + def finalize(self): + f = asg.Function(name=self._name, symtab=self._function_scope, + args=self._args, return_var=self._return_var, body=self._body) + _add_symbol(self._parent_scope, f) + check_function(f) + return f + +class TranslationUnit(): + + def __init__(self): + self._global_scope = Scope() + pass + + def make_module(self, name): + module_scope = Scope(self._global_scope) + m = asg.Module(name=name, symtab=module_scope) + _add_symbol(self._global_scope, m) + return m