diff --git a/hugr-py/src/hugr/hugr/node_port.py b/hugr-py/src/hugr/hugr/node_port.py index 556692aca..569043f5e 100644 --- a/hugr-py/src/hugr/hugr/node_port.py +++ b/hugr-py/src/hugr/hugr/node_port.py @@ -176,22 +176,22 @@ def _index( ) raise ValueError(msg) - start = self._normalize_index(start) - stop = self._normalize_index(stop, allow_eq_len=True) + start = self._normalize_index(start, allow_overflow=True) + stop = self._normalize_index(stop, allow_overflow=True) step = index.step or 1 return (self[i] for i in range(start, stop, step)) case tuple(xs): return (self[i] for i in xs) - def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int: + def _normalize_index(self, index: int, allow_overflow: bool = False) -> int: """Given an index passed to `__getitem__`, normalize it to be within the range of output ports. Args: index: index to normalize. - allow_eq_len: whether to allow the index to be equal to the number of - output ports. + allow_overflow: whether to allow indices beyond the number of outputs. + If True, indices over `self._num_out_ports` will be truncated. Returns: Normalized index. @@ -202,9 +202,7 @@ def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int: msg = f"Index {index} out of range" if self._num_out_ports is not None: - if index > self._num_out_ports: - raise IndexError(msg) - if index == self._num_out_ports and not allow_eq_len: + if index >= self._num_out_ports and not allow_overflow: raise IndexError(msg) if index < -self._num_out_ports: raise IndexError(msg) @@ -212,7 +210,9 @@ def _normalize_index(self, index: int, allow_eq_len: bool = False) -> int: if index < 0: raise IndexError(msg) - if index >= 0: + if index >= 0 and self._num_out_ports is not None: + return min(index, self._num_out_ports) + elif index >= 0: return index else: assert self._num_out_ports is not None diff --git a/hugr-py/tests/test_nodes.py b/hugr-py/tests/test_nodes.py index 16c3b259a..000f14423 100644 --- a/hugr-py/tests/test_nodes.py +++ b/hugr-py/tests/test_nodes.py @@ -30,8 +30,19 @@ def test_slices(): assert list(n[0:]) == all_ports assert list(n[:3]) == all_ports assert list(n[0:3]) == all_ports + assert list(n[0:999]) == all_ports + assert list(n[999:1000]) == [] assert list(n[-1:]) == [OutPort(n, 2)] assert list(n[-3:]) == all_ports with pytest.raises(IndexError, match="Index -4 out of range"): _ = n[-4:] + + n0 = Node(0, _num_out_ports=0) + assert list(n0) == [] + assert list(n0[:0]) == [] + assert list(n0[:10]) == [] + assert list(n0[0:0]) == [] + assert list(n0[0:]) == [] + assert list(n0[10:]) == [] + assert list(n0[:]) == []