Skip to content

Commit

Permalink
feat(hugr-py): builder ops separate from serialised ops (#1140)
Browse files Browse the repository at this point in the history
also move the "number of input wires" interface convenience to calls on
ops

"builder ops" are currently very similar to the serialised ones - this
will change once builder types are separated from serialised types
  • Loading branch information
ss2165 authored Jun 4, 2024
1 parent ea8905a commit 342eda3
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 125 deletions.
66 changes: 13 additions & 53 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import BaseOp, OpType as SerialOp
import hugr.serialization.ops as sops
from hugr.serialization.tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG, Command
from hugr.utils import BiMap


Expand Down Expand Up @@ -101,35 +101,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort:
return self.out(offset)


class Op(Protocol):
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ...

@classmethod
def from_serial(cls, serial: SerialOp) -> Self: ...


T = TypeVar("T", bound=BaseOp)


@dataclass()
class DummyOp(Op, Generic[T]):
_serial_op: T

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=self._serial_op.model_copy()) # type: ignore

@classmethod
def from_serial(cls, serial: SerialOp) -> DummyOp:
return DummyOp(serial.root)


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass()
class NodeData:
op: Op
Expand All @@ -139,10 +110,9 @@ class NodeData:
# TODO children field?

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, hugr)
o.root.parent = self.parent.idx if self.parent else node.idx
o = self.op.to_serial(node, self.parent if self.parent else node, hugr)

return o
return SerialOp(root=o) # type: ignore[arg-type]


P = TypeVar("P", InPort, OutPort)
Expand Down Expand Up @@ -372,7 +342,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
hugr.root = Node(idx)
parent = None
serial_node.root.parent = -1
hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent))
hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent))

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
if src_offset is None or dst_offset is None:
Expand All @@ -396,43 +366,33 @@ def __init__(
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DummyOp(sops.DFG(parent=-1))
root_op._serial_op.signature.input = input_types
root_op._serial_op.signature.output = output_types
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
DummyOp(sops.Input(parent=0, types=input_types)),
self.root,
len(input_types),
)
self.output_node = self.hugr.add_node(
DummyOp(sops.Output(parent=0, types=output_types)), self.root
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> DummyOp[sops.Input]:
def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, DummyOp)
assert isinstance(dop._serial_op, sops.Input)
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [
self.input_node.out(i)
for i in range(len(self._input_op()._serial_op.types))
]
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out())
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
Expand Down
135 changes: 135 additions & 0 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire


class Op(Protocol):
@property
def num_out(self) -> int | None:
return None

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...

def __call__(self, *args) -> Command:
return Command(self, list(args))


@dataclass(frozen=True)
class Command:
op: Op
incoming: list[Wire]


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:
root = self._serial_op.model_copy()
root.parent = parent.idx
return root


@dataclass()
class Input(Op):
types: list[tys.Type]

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)

def __call__(self) -> Command:
return super().__call__()


@dataclass()
class Output(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)


@dataclass()
class Custom(Op):
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
extension: tys.ExtensionId = ""
args: list[tys.TypeArg] = field(default_factory=list)

@property
def num_out(self) -> int | None:
return len(self.signature.output)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
return sops.CustomOp(
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
description=self.description,
args=self.args,
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, *elements: Wire) -> Command:
return super().__call__(*elements)


@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, tuple_: Wire) -> Command:
return super().__call__(tuple_)


@dataclass()
class DFG(Op):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

@property
def num_out(self) -> int | None:
return len(self.signature.output)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
)
30 changes: 30 additions & 0 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import inspect
import sys
from abc import ABC
Expand Down Expand Up @@ -41,6 +42,10 @@ def display_name(self) -> str:
"""Name of the op for visualisation"""
return self.__class__.__name__

def deserialize(self) -> _ops.Op:
"""Deserializes the model into the corresponding Op."""
return _ops.SerWrap(self)


# ----------------------------------------------------------
# --------------- Module level operations ------------------
Expand Down Expand Up @@ -209,6 +214,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(in_types) == 0
self.types = list(out_types)

def deserialize(self) -> _ops.Input:
return _ops.Input(types=self.types)


class Output(DataflowOp):
"""An output node. The inputs are the outputs of the function."""
Expand All @@ -220,6 +228,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(out_types) == 0
self.types = list(in_types)

def deserialize(self) -> _ops.Output:
return _ops.Output(types=self.types)


class Call(DataflowOp):
"""
Expand Down Expand Up @@ -292,6 +303,9 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([])
)

def deserialize(self) -> _ops.DFG:
return _ops.DFG(self.signature)


# ------------------------------------------------
# --------------- ControlFlowOp ------------------
Expand Down Expand Up @@ -388,6 +402,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
def display_name(self) -> str:
return self.op_name

def deserialize(self) -> _ops.Custom:
return _ops.Custom(
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
args=self.args,
)

model_config = ConfigDict(
# Needed to avoid random '\n's in the pydantic description
json_schema_extra={
Expand Down Expand Up @@ -424,6 +446,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
in_types = []
self.tys = list(in_types)

def deserialize(self) -> _ops.MakeTuple:
return _ops.MakeTuple(self.tys)


class UnpackTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""
Expand All @@ -434,6 +459,9 @@ class UnpackTuple(DataflowOp):
def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(out_types)

def deserialize(self) -> _ops.UnpackTuple:
return _ops.UnpackTuple(self.tys)


class Tag(DataflowOp):
"""An operation that creates a tagged sum value from one of its variants."""
Expand Down Expand Up @@ -529,3 +557,5 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True):
)

tys_model_rebuild(dict(classes))

from hugr import _ops # noqa: E402 # needed to avoid circular imports
Loading

0 comments on commit 342eda3

Please sign in to comment.