Skip to content

Commit

Permalink
feat: Upgrade Hugr and start using the shared Pydantic model (#201)
Browse files Browse the repository at this point in the history
This PR makes Guppy depend on Hugr at main as a path for upgrading to
v0.4. Closes #198

Most of the edits are related to changes in the Pydantic model. However,
we also had to get rid of `TypeApply` since this no longer exists in
Hugr. Therefore, Guppy now rejects all programs that use polymorphic
functions as values if the type arguments cannot be inferred
  • Loading branch information
mark-koch authored May 16, 2024
1 parent f7adb85 commit bd7e67a
Show file tree
Hide file tree
Showing 51 changed files with 1,365 additions and 1,630 deletions.
40 changes: 18 additions & 22 deletions guppylang/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
)
from guppylang.compiler.expr_compiler import ExprCompiler
from guppylang.compiler.stmt_compiler import StmtCompiler
from guppylang.hugr.hugr import CFNode, Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import CFNode, Hugr, Node, OutPortV
from guppylang.tys.builtin import is_bool_type
from guppylang.tys.ty import SumType, TupleType, type_to_row
from guppylang.tys.ty import SumType, row_to_type, type_to_row

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -65,9 +65,8 @@ def compile_bb(
branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg)
else:
# Even if we don't branch, we still have to add a `Sum(())` predicates
unit = graph.add_make_tuple([], parent=block).out_port(0)
branch_port = graph.add_tag(
variants=[TupleType([])], tag=0, inp=unit, parent=block
variants=[[]], tag=0, inputs=[], parent=block
).out_port(0)

# Finally, we have to add the block output.
Expand All @@ -92,7 +91,11 @@ def compile_bb(
graph=graph,
unit_sum=branch_port,
output_vars=[
[v for v in row if not v.ty.linear or is_return_var(v.name)]
[
v
for v in sort_vars(row)
if not v.ty.linear or is_return_var(v.name)
]
for row in bb.sig.output_rows
],
dfg=dfg,
Expand Down Expand Up @@ -133,30 +136,23 @@ def choose_vars_for_tuple_sum(
) -> OutPortV:
"""Selects an output based on a TupleSum.
Given `unit_sum: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`,
constructs a TupleSum value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`.
Given `unit_sum: Sum(*(), *(), ...)` and output variable rows `#s1, #s2, ...`,
constructs a TupleSum value of type `Sum(#s1, #s2, ...)`.
"""
assert isinstance(unit_sum.ty, SumType) or is_bool_type(unit_sum.ty)
assert len(output_vars) == (
len(unit_sum.ty.element_types) if isinstance(unit_sum.ty, SumType) else 2
)
tuples = [
graph.add_make_tuple(
inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg],
parent=dfg.node,
).out_port(0)
for vs in output_vars
]
tys = [t.ty for t in tuples]
conditional = graph.add_conditional(
cond_input=unit_sum, inputs=tuples, parent=dfg.node
)
for i, _ty in enumerate(tys):
assert all(not v.ty.linear for var_row in output_vars for v in var_row)
conditional = graph.add_conditional(cond_input=unit_sum, inputs=[], parent=dfg.node)
tys = [[v.ty for v in var_row] for var_row in output_vars]
for i, var_row in enumerate(output_vars):
case = graph.add_case(conditional)
inp = graph.add_input(output_tys=tys, parent=case).out_port(i)
tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0)
graph.add_input(output_tys=[], parent=case)
inputs = [dfg[v.name].port for v in var_row]
tag = graph.add_tag(variants=tys, tag=i, inputs=inputs, parent=case).out_port(0)
graph.add_output(inputs=[tag], parent=case)
return conditional.add_out_port(SumType(tys))
return conditional.add_out_port(SumType([row_to_type(row) for row in tys]))


def compare_var(x: Variable, y: Variable) -> int:
Expand Down
2 changes: 1 addition & 1 deletion guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from guppylang.ast_util import AstNode
from guppylang.checker.core import Variable
from guppylang.definition.common import CompiledDef, DefId
from guppylang.hugr.hugr import DFContainingNode, Hugr, OutPortV
from guppylang.hugr_builder.hugr import DFContainingNode, Hugr, OutPortV


@dataclass
Expand Down
72 changes: 48 additions & 24 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
from typing import Any, TypeGuard, TypeVar

from hugr.serialization import ops

from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type
from guppylang.cfg.builder import tmp_vars
Expand All @@ -13,8 +15,13 @@
)
from guppylang.definition.value import CompiledCallableDef, CompiledValueDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr import ops, val
from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode
from guppylang.hugr_builder.hugr import (
UNDEFINED,
DFContainingNode,
DummyOp,
OutPortV,
VNode,
)
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
Expand Down Expand Up @@ -149,6 +156,12 @@ def visit_LocalName(self, node: LocalName) -> OutPortV:
def visit_GlobalName(self, node: GlobalName) -> OutPortV:
defn = self.globals[node.def_id]
assert isinstance(defn, CompiledValueDef)
if isinstance(defn, CompiledCallableDef) and defn.ty.parametrized:
raise GuppyError(
"Usage of polymorphic functions as dynamic higher-order values is not "
"supported yet",
node,
)
return defn.load(self.dfg, self.graph, self.globals, node)

def visit_Name(self, node: ast.Name) -> OutPortV:
Expand All @@ -162,7 +175,7 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV:
def visit_List(self, node: ast.List) -> OutPortV:
# Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension
return self.graph.add_node(
ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts]
DummyOp("MakeList"), inputs=[self.visit(e) for e in node.elts]
).add_out_port(get_type(node))

def _unpack_tuple(self, wire: OutPortV) -> list[OutPortV]:
Expand Down Expand Up @@ -243,9 +256,11 @@ def visit_Call(self, node: ast.Call) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")

def visit_TypeApply(self, node: TypeApply) -> OutPortV:
func = self.visit(node.value)
assert isinstance(func.ty, FunctionType)
ta = self.graph.add_type_apply(func, node.inst, self.dfg.node).out_port(0)
# For now, we can only TypeApply global FunctionDefs/Decls.
if not isinstance(node.value, GlobalName):
raise InternalGuppyError("Dynamic TypeApply not supported yet!")
defn = self.globals[node.value.def_id]
assert isinstance(defn, CompiledCallableDef)

# We have to be very careful here: If we instantiate `foo: forall T. T -> T`
# with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`.
Expand All @@ -254,22 +269,25 @@ def visit_TypeApply(self, node: TypeApply) -> OutPortV:
# function with a single output port typed `tuple[A, B]`.
# TODO: We would need to do manual monomorphisation in that case to obtain a
# function that returns two ports as expected
if instantiation_needs_unpacking(func.ty, node.inst):
if instantiation_needs_unpacking(defn.ty, node.inst):
raise GuppyError(
"Generic function instantiations returning rows are not supported yet",
node,
)

return ta
return defn.load_with_args(node.inst, self.dfg, self.graph, self.globals, node)

def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV:
# The only case that is not desugared by the type checker is the `not` operation
# since it is not implemented via a dunder method
if isinstance(node.op, ast.Not):
arg = self.visit(node.operand)
return self.graph.add_node(
ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg]
).add_out_port(bool_type())
op = ops.CustomOp(
extension="logic", op_name="Not", args=[], parent=UNDEFINED
)
return self.graph.add_node(ops.OpType(op), inputs=[arg]).add_out_port(
bool_type()
)

raise InternalGuppyError("Node should have been removed during type checking.")

Expand All @@ -281,7 +299,7 @@ def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV:
# Make up a name for the list under construction and bind it to an empty list
list_ty = get_type(node)
list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars))))
empty_list = self.graph.add_node(ops.DummyOp(name="MakeList"))
empty_list = self.graph.add_node(DummyOp("MakeList"))
self.dfg[list_name.id] = PortVariable(
list_name.id, empty_list.add_out_port(list_ty), node, None
)
Expand All @@ -292,7 +310,7 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None:
if not gens:
list_port, elt_port = self.visit(list_name), self.visit(elt)
push = self.graph.add_node(
ops.DummyOp(name="Push"), inputs=[list_port, elt_port]
DummyOp("Push"), inputs=[list_port, elt_port]
)
self.dfg[list_name.id].port = push.add_out_port(list_port.ty)
return
Expand Down Expand Up @@ -348,7 +366,7 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
return False


def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None:
def python_value_to_hugr(v: Any, exp_ty: Type) -> ops.Value | None:
"""Turns a Python value into a Hugr value.
Returns None if the Python value cannot be represented in Guppy.
Expand All @@ -373,15 +391,13 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None:
python_value_to_hugr(elt, ty)
for elt, ty in zip(elts, exp_ty.element_types)
]
if any(value is None for value in vs):
return None
return val.Tuple(vs=vs)
if doesnt_contain_none(vs):
return ops.Value(ops.TupleValue(vs=vs))
case list(elts):
assert is_list_type(exp_ty)
return list_value(
[python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts],
get_element_type(exp_ty).to_hugr(),
)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
if doesnt_contain_none(vs):
return list_value(vs, get_element_type(exp_ty))
case _:
# Pytket conversion is an optional feature
try:
Expand All @@ -393,7 +409,15 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None:
)

hugr = json.loads(Tk2Circuit(v).to_hugr_json())
return val.FunctionVal(hugr=hugr)
return ops.Value(ops.FunctionValue(hugr=hugr))
except ImportError:
pass
return None
return None


T = TypeVar("T")


def doesnt_contain_none(xs: list[T | None]) -> TypeGuard[list[T]]:
"""Checks if a list contains `None`."""
return all(x is not None for x in xs)
6 changes: 3 additions & 3 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DFContainer,
PortVariable,
)
from guppylang.hugr.hugr import DFContainingVNode, Hugr
from guppylang.hugr_builder.hugr import DFContainingVNode, Hugr
from guppylang.nodes import CheckedNestedFunctionDef
from guppylang.tys.ty import FunctionType, type_to_row

Expand Down Expand Up @@ -60,7 +60,7 @@ def compile_local_func_def(
# the function itself, then we provide the partially applied function as a local
# variable
if len(captured) > 0 and func.name in func.cfg.live_before[func.cfg.entry_bb]:
loaded = graph.add_load_constant(def_node.out_port(0), def_node).out_port(0)
loaded = graph.add_load_function(def_node.out_port(0), [], def_node).out_port(0)
partial = graph.add_partial(
loaded, [def_input.out_port(i) for i in range(len(captured))], def_node
)
Expand Down Expand Up @@ -93,7 +93,7 @@ def compile_local_func_def(
)

# Finally, load the function into the local data-flow graph
loaded = graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0)
loaded = graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)
if len(captured) > 0:
loaded = graph.add_partial(
loaded, [dfg[v.name].port for v in captured], dfg.node
Expand Down
2 changes: 1 addition & 1 deletion guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from guppylang.compiler.expr_compiler import ExprCompiler
from guppylang.error import InternalGuppyError
from guppylang.hugr.hugr import Hugr, OutPortV
from guppylang.hugr_builder.hugr import Hugr, OutPortV
from guppylang.nodes import CheckedNestedFunctionDef
from guppylang.tys.ty import TupleType

Expand Down
5 changes: 3 additions & 2 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from types import ModuleType
from typing import Any, TypeVar

from hugr.serialization import ops, tys

from guppylang.ast_util import has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.custom import (
Expand All @@ -21,8 +23,7 @@
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.hugr import ops, tys
from guppylang.hugr.hugr import Hugr
from guppylang.hugr_builder.hugr import Hugr
from guppylang.module import GuppyModule, PyFunc

FuncDefDecorator = Callable[[PyFunc], RawFunctionDef]
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeAlias

from guppylang.hugr.hugr import Hugr, Node
from guppylang.hugr_builder.hugr import Hugr, Node

if TYPE_CHECKING:
from guppylang.checker.core import Globals
Expand Down
29 changes: 15 additions & 14 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from hugr.serialization import ops

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.checker.core import Context, Globals
from guppylang.checker.expr_checker import check_call, synthesize_call
Expand All @@ -10,8 +12,7 @@
from guppylang.definition.common import ParsableDef
from guppylang.definition.value import CompiledCallableDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr import ops
from guppylang.hugr.hugr import Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row
Expand Down Expand Up @@ -123,8 +124,13 @@ def synthesize_call(
new_node, ty = self.call_checker.synthesize(args)
return with_type(ty, with_loc(node, new_node)), ty

def load(
self, dfg: "DFContainer", graph: Hugr, globals: CompiledGlobals, node: AstNode
def load_with_args(
self,
type_args: Inst,
dfg: "DFContainer",
graph: Hugr,
globals: CompiledGlobals,
node: AstNode,
) -> OutPortV:
"""Loads the custom function as a value into a local dataflow graph.
Expand All @@ -138,12 +144,7 @@ def load(
"This function does not support usage in a higher-order context",
node,
)

if self.ty.parametrized:
raise InternalGuppyError(
"Can't yet generate higher-order versions of custom functions. This "
"requires generic function *definitions*"
)
assert len(self.ty.params) == len(type_args)

# Find the module node by walking up the hierarchy
module: Node = dfg.node
Expand All @@ -159,7 +160,7 @@ def load(
def_node = graph.add_def(self.ty, module, self.name)
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node)
returns = self.compile_call(
inp_ports, [], DFContainer(def_node, {}), graph, globals, node
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns, parent=def_node)

Expand Down Expand Up @@ -251,14 +252,14 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]:
class OpCompiler(CustomCallCompiler):
"""Call compiler for functions that are directly implemented via Hugr ops."""

op: ops.BaseOp
op: ops.OpType

def __init__(self, op: ops.BaseOp) -> None:
def __init__(self, op: ops.OpType) -> None:
self.op = op

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
node = self.graph.add_node(
self.op.model_copy(), inputs=args, parent=self.dfg.node
self.op.model_copy(deep=True), inputs=args, parent=self.dfg.node
)
return_ty = get_type(self.node)
return [node.add_out_port(ty) for ty in type_to_row(return_ty)]
Expand Down
Loading

0 comments on commit bd7e67a

Please sign in to comment.