Skip to content

Commit

Permalink
test stable indices
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed May 22, 2024
1 parent dd16870 commit 1cd9598
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
13 changes: 9 additions & 4 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __iter__(self):
return iter(self._nodes)

def __len__(self) -> int:
return len(self._nodes) - len(self._free_nodes)
return self.num_nodes()

def add_node(
self,
Expand All @@ -155,15 +155,17 @@ def add_node(
self._nodes.append(node_data)
return node

def delete_node(self, node: Node) -> None:
def delete_node(self, node: Node) -> NodeData | None:
for offset in range(self.num_in_ports(node)):
self._links.delete_right(node.inp(offset))
for offset in range(self.num_out_ports(node)):
self._links.delete_left(node.out(offset))
self._nodes[node.idx] = None

weight, self._nodes[node.idx] = self._nodes[node.idx], None
self._free_nodes.append(node)
return weight

def add_link(self, src: OutPort, dst: InPort, ty: Type | None = None) -> None:
def add_link(self, src: OutPort, dst: InPort) -> None:
src = _unused_sub_offset(src, self._links)
dst = _unused_sub_offset(dst, self._links)
if self._links.get_left(dst) is not None:
Expand All @@ -173,6 +175,9 @@ def add_link(self, src: OutPort, dst: InPort, ty: Type | None = None) -> None:
def delete_link(self, src: OutPort, dst: InPort) -> None:
self._links.delete_left(src)

def num_nodes(self) -> int:
return len(self._nodes) - len(self._free_nodes)

def num_in_ports(self, node: Node) -> int:
return len(self.in_ports(node))

Expand Down
35 changes: 34 additions & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from hugr.hugr import Dfg, Hugr, DummyOp, Node
import hugr.serialization.tys as stys
import hugr.serialization.ops as sops
import pytest

BOOL_T = stys.Type(stys.SumType(stys.UnitSum(size=2)))
QB_T = stys.Type(stys.Qubit())
Expand Down Expand Up @@ -38,7 +39,6 @@


def _validate(h: Hugr, mermaid: bool = False):
# TODO point to built hugr binary
# cmd = ["cargo", "run", "--features", "cli", "--"]
cmd = ["./target/debug/hugr"]

Expand All @@ -47,6 +47,39 @@ def _validate(h: Hugr, mermaid: bool = False):
subprocess.run(cmd + ["-"], check=True, input=h.to_serial().to_json().encode())


def test_stable_indices():
h = Hugr(DummyOp(sops.DFG(parent=-1)))

nodes = [h.add_node(NOT_OP) for _ in range(3)]
assert len(h) == 4

h.add_link(nodes[0].out(0), nodes[1].inp(0))

assert h.num_out_ports(nodes[0]) == 1
assert h.num_in_ports(nodes[1]) == 1

assert h.delete_node(nodes[1]) is not None
assert h._nodes[nodes[1].idx] is None

assert len(h) == 3
assert len(h._nodes) == 4
assert h._free_nodes == [nodes[1]]

assert h.num_out_ports(nodes[0]) == 0
assert h.num_in_ports(nodes[1]) == 0

with pytest.raises(KeyError):
_ = h[nodes[1]]
with pytest.raises(KeyError):
_ = h[Node(46)]

new_n = h.add_node(NOT_OP)
assert new_n == nodes[1]

assert len(h) == 4
assert h._free_nodes == []


def test_simple_id():
h = Dfg.endo([QB_T] * 2)
a, b = h.inputs()
Expand Down

0 comments on commit 1cd9598

Please sign in to comment.