Skip to content

Commit

Permalink
dodgy support for commands
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed May 22, 2024
1 parent 669e5ed commit de22b5a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 21 deletions.
33 changes: 26 additions & 7 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass, replace
from dataclasses import dataclass, field, replace

from collections.abc import Collection, Mapping
from typing import Iterable, Sequence, Protocol, Generic, TypeVar, overload
Expand Down Expand Up @@ -35,6 +35,7 @@ def to_port(self) -> "OutPort":
@dataclass(frozen=True, eq=True, order=True)
class Node(ToPort):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

@overload
def __getitem__(self, index: int) -> OutPort: ...
Expand All @@ -48,12 +49,17 @@ def __getitem__(
) -> OutPort | Iterable[OutPort]:
match index:
case int(index):
if self._num_out_ports is not None:
if index >= self._num_out_ports:
raise IndexError("Index out of range")
return self.out(index)
case slice():
start = index.start or 0
stop = index.stop
stop = index.stop or self._num_out_ports
if stop is None:
raise ValueError("Stop must be specified")
raise ValueError(
"Stop must be specified when number of outputs unknown"
)
step = index.step or 1
return (self[i] for i in range(start, stop, step))
case tuple(xs):
Expand Down Expand Up @@ -84,6 +90,13 @@ def to_serial(self, node: Node, hugr: "Hugr") -> SerialOp:
return SerialOp(root=self._serial_op) # type: ignore


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[ToPort]: ...
def num_out(self) -> int | None:
return None


@dataclass()
class NodeData:
op: Op
Expand Down Expand Up @@ -144,6 +157,7 @@ def add_node(
self,
op: Op,
parent: Node | None = None,
num_outs: int | None = None,
) -> Node:
node_data = NodeData(op, parent)

Expand All @@ -153,7 +167,7 @@ def add_node(
else:
node = Node(len(self._nodes))
self._nodes.append(node_data)
return node
return replace(node, _num_out_ports=num_outs)

def delete_node(self, node: Node) -> NodeData | None:
for offset in range(self.num_in_ports(node)):
Expand Down Expand Up @@ -245,7 +259,9 @@ def __init__(
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
DummyOp(sops.Input(parent=0, types=input_types)), self.root
DummyOp(sops.Input(parent=0, types=input_types)),
self.root,
len(input_types),
)
self.output_node = self.hugr.add_node(
DummyOp(sops.Output(parent=0, types=output_types)), self.root
Expand All @@ -267,11 +283,14 @@ def inputs(self) -> list[OutPort]:
for i in range(len(self._input_op()._serial_op.types))
]

def add_op(self, op: Op, /, *args: ToPort) -> Node:
new_n = self.hugr.add_node(op, self.root)
def add_op(self, op: Op, /, *args: ToPort, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out())

def insert_nested(self, dfg: "Dfg", *args: ToPort) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
Expand Down
56 changes: 42 additions & 14 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
import subprocess

from hugr.hugr import Dfg, Hugr, DummyOp, Node
from hugr.hugr import Dfg, Hugr, DummyOp, Node, Command, ToPort, Op
import hugr.serialization.tys as stys
import hugr.serialization.ops as sops
import pytest
Expand All @@ -27,15 +28,42 @@
)
)

DIV_OP = DummyOp(
sops.CustomOp(
parent=-1,
extension="arithmetic.int",
op_name="idivmod_u",
signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2),
args=[ARG_5, ARG_5],
)
)

@dataclass
class Not(Command):
a: ToPort

def incoming(self) -> list[ToPort]:
return [self.a]

def num_out(self) -> int | None:
return 1

def op(self) -> Op:
return NOT_OP


@dataclass
class DivMod(Command):
a: ToPort
b: ToPort

def incoming(self) -> list[ToPort]:
return [self.a, self.b]

def num_out(self) -> int | None:
return 2

def op(self) -> Op:
return DummyOp(
sops.CustomOp(
parent=-1,
extension="arithmetic.int",
op_name="idivmod_u",
signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2),
args=[ARG_5, ARG_5],
)
)


def _validate(h: Hugr, mermaid: bool = False):
Expand Down Expand Up @@ -127,7 +155,7 @@ def test_tuple():
def test_multi_out():
h = Dfg([INT_T] * 2, [INT_T] * 2)
a, b = h.inputs()
a, b = h.add_op(DIV_OP, a, b)[:2]
a, b = h.add(DivMod(a, b))
h.set_outputs(a, b)

_validate(h.hugr)
Expand All @@ -136,7 +164,7 @@ def test_multi_out():
def test_insert():
h1 = Dfg.endo([BOOL_T])
(a1,) = h1.inputs()
nt = h1.add_op(NOT_OP, a1)
nt = h1.add(Not(a1))
h1.set_outputs(nt)

assert len(h1.hugr) == 4
Expand All @@ -149,7 +177,7 @@ def test_insert():
def test_insert_nested():
h1 = Dfg.endo([BOOL_T])
(a1,) = h1.inputs()
nt = h1.add_op(NOT_OP, a1)
nt = h1.add(Not(a1))
h1.set_outputs(nt)

h = Dfg.endo([BOOL_T])
Expand All @@ -163,7 +191,7 @@ def test_insert_nested():
def test_build_nested():
def _nested_nop(dfg: Dfg):
(a1,) = dfg.inputs()
nt = dfg.add_op(NOT_OP, a1)
nt = dfg.add(Not(a1))
dfg.set_outputs(nt)

h = Dfg.endo([BOOL_T])
Expand Down

0 comments on commit de22b5a

Please sign in to comment.