Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Aug 10, 2023
1 parent d204eea commit 917c1ee
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
9 changes: 4 additions & 5 deletions test/loader/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from torch_geometric.loader.utils import index_select


@pytest.mark.parametrize("", [1, 2, 3])
def test_index_select():
x = torch.randn(3, 5)
index = torch.tensor([0, 2])
assert torch.all(index_select(x, index) == x[index])
assert torch.all(index_select(x, index, dim=-1) == x[..., index])
assert torch.equal(index_select(x, index), x[index])
assert torch.equal(index_select(x, index, dim=-1), x[..., index])


def test_index_select_invalid_value():
with pytest.raises(ValueError, "Encountered invalid feature tensor"):
def test_index_select_out_of_range():
with pytest.raises(IndexError, match="out of range"):
index_select(torch.randn(3, 5), torch.tensor([0, 2, 3]))
23 changes: 9 additions & 14 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,23 @@ def index_select(
index: Tensor,
dim: int = 0,
) -> Tensor:
"""Selects and returns features from :obj:`value` according to :obj:`index`.
r"""Indexes the :obj:`value` tensor along dimension :obj:`dim` using the
entries in :obj:`index`.
Arguments:
value (Tensor or np.ndarray): The feature tensor.
index (Tensor): The indices of the features to select.
Args:
value (torch.Tensor or np.ndarray): The input tensor.
index (torch.Tensor): The 1-D tensor containing the indices to index.
dim (int, optional): The dimension in which to index.
(default: :obj:`0`)
Returns:
:obj:`Tensor`: The selected features.
Raises:
:class:`ValueError`: If :obj:`value` is neither a :obj:`Tensor` nor a
:obj:`np.ndarray`.
.. warning::
:obj:`index` is casted to a :obj:`torch.int64` tensor internally, as
PyTorch currently only supports indexing via :obj:`torch.int64`.
See https://github.com/pytorch/pytorch/issues/61819.
`PyTorch currently only supports indexing
<https://github.com/pytorch/pytorch/issues/61819>`_ via
:obj:`torch.int64`.
"""
# PyTorch currently only supports indexing via `torch.int64` :(
# PyTorch currently only supports indexing via `torch.int64`:
# https://github.com/pytorch/pytorch/issues/61819
index = index.to(torch.int64)

Expand Down

0 comments on commit 917c1ee

Please sign in to comment.