Skip to content

Commit

Permalink
feat: inter graph + state order edge
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed May 28, 2024
1 parent 83a5b19 commit 600f18d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
39 changes: 23 additions & 16 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,6 @@ def next_sub_offset(self) -> Self:
_SI = _SubPort[InPort]


def _unused_sub_offset(sub_port: _SubPort[P], links: BiMap[_SO, _SI]) -> _SubPort[P]:
d: dict[_SO, _SI] | dict[_SI, _SO]
match sub_port.port:
case OutPort(_):
d = links.fwd
case InPort(_):
d = links.bck
while sub_port in d:
sub_port = sub_port.next_sub_offset()
return sub_port


@dataclass()
class Hugr(Mapping[Node, NodeData]):
root: Node
Expand Down Expand Up @@ -219,9 +207,21 @@ def delete_node(self, node: Node) -> NodeData | None:
self._free_nodes.append(node)
return weight

def _unused_sub_offset(self, port: P) -> _SubPort[P]:
d: dict[_SO, _SI] | dict[_SI, _SO]
match port:
case OutPort(_):
d = self._links.fwd
case InPort(_):
d = self._links.bck
sub_port = _SubPort(port)
while sub_port in d:
sub_port = sub_port.next_sub_offset()
return sub_port

def add_link(self, src: OutPort, dst: InPort) -> None:
src_sub = _unused_sub_offset(_SubPort(src), self._links)
dst_sub = _unused_sub_offset(_SubPort(dst), self._links)
src_sub = self._unused_sub_offset(src)
dst_sub = self._unused_sub_offset(dst)
# if self._links.get_left(dst_sub) is not None:
# dst = replace(dst, _sub_offset=dst._sub_offset + 1)
self._links.insert_left(src_sub, dst_sub)
Expand Down Expand Up @@ -403,11 +403,11 @@ def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
ports: Iterable[Wire],
*args: Wire,
) -> Dfg:
dfg = Dfg(input_types, output_types)
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], ports)
self._wire_up(mapping[dfg.root], args)
dfg.hugr = self.hugr
dfg.input_node = mapping[dfg.input_node]
dfg.output_node = mapping[dfg.output_node]
Expand All @@ -429,6 +429,13 @@ def split_tuple(self, tys: Sequence[Type], port: Wire) -> list[OutPort]:

return [n.out(i) for i in range(len(tys))]

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
# breaks if further edges are added
self.hugr.add_link(
src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst))
)

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
Expand Down
17 changes: 16 additions & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,25 @@ def _nested_nop(dfg: Dfg):

h = Dfg.endo([BOOL_T])
(a,) = h.inputs()
nested = h.add_nested([BOOL_T], [BOOL_T], [a])
nested = h.add_nested([BOOL_T], [BOOL_T], a)

_nested_nop(nested)

h.set_outputs(nested.root)

_validate(h.hugr)


def test_build_inter_graph():
h = Dfg.endo([BOOL_T])
(a,) = h.inputs()
nested = h.add_nested([], [BOOL_T])

nt = nested.add(Not(a))
nested.set_outputs(nt)
# TODO a context manager could add this state order edge on
# exit by tracking parents of source nodes
h.add_state_order(h.input_node, nested.root)
h.set_outputs(nested.root)

_validate(h.hugr)

0 comments on commit 600f18d

Please sign in to comment.