Skip to content

Commit

Permalink
Support for sampling last k neighbors as temporal sampling strategy (
Browse files Browse the repository at this point in the history
…#5576)

* update

* update

* changelog

* update

* update

* update
  • Loading branch information
rusty1s committed Sep 30, 2022
1 parent e998348 commit 00c3a5d
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `temporal_strategy` option to `neighbor_sample` ([#5576](https://github.com/pyg-team/pyg-lib/pull/5576))
- Added `torch_geometric.sampler` package to docs ([#5563](https://github.com/pyg-team/pytorch_geometric/pull/5563))
- Added the `DGraphFin` dynamic graph dataset ([#5504](https://github.com/pyg-team/pytorch_geometric/pull/5504))
- Added `dropout_edge` augmentation that randomly drops edges from a graph - the usage of `dropout_adj` is now deprecated ([#5495](https://github.com/pyg-team/pytorch_geometric/pull/5495))
Expand Down
9 changes: 9 additions & 0 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ class LinkNeighborLoader(LinkLoader):
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
If set to :obj:`"uniform"`, will sample uniformly across neighbors
that fulfill temporal constraints.
If set to :obj:`"last"`, will sample the last `num_neighbors` that
fulfill temporal constraints.
(default: :obj:`"uniform"`)
neg_sampling_ratio (float, optional): The ratio of sampled negative
edges to the number of positive edges.
If :obj:`neg_sampling_ratio > 0` and in case :obj:`edge_label`
Expand Down Expand Up @@ -150,6 +157,7 @@ def __init__(
edge_label_time: OptTensor = None,
replace: bool = False,
directed: bool = True,
temporal_strategy: str = 'uniform',
neg_sampling_ratio: float = 0.0,
time_attr: Optional[str] = None,
transform: Callable = None,
Expand Down Expand Up @@ -180,6 +188,7 @@ def __init__(
num_neighbors=num_neighbors,
replace=replace,
directed=directed,
temporal_strategy=temporal_strategy,
input_type=edge_type,
time_attr=time_attr,
is_sorted=is_sorted,
Expand Down
9 changes: 9 additions & 0 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ class NeighborLoader(NodeLoader):
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
If set to :obj:`"uniform"`, will sample uniformly across neighbors
that fulfill temporal constraints.
If set to :obj:`"last"`, will sample the last `num_neighbors` that
fulfill temporal constraints.
(default: :obj:`"uniform"`)
time_attr (str, optional): The name of the attribute that denotes
timestamps for the nodes in the graph.
If set, temporal sampling will be used such that neighbors are
Expand Down Expand Up @@ -159,6 +166,7 @@ def __init__(
input_nodes: InputNodes = None,
replace: bool = False,
directed: bool = True,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
transform: Callable = None,
is_sorted: bool = False,
Expand All @@ -177,6 +185,7 @@ def __init__(
num_neighbors=num_neighbors,
replace=replace,
directed=directed,
temporal_strategy=temporal_strategy,
input_type=node_type,
time_attr=time_attr,
is_sorted=is_sorted,
Expand Down
5 changes: 5 additions & 0 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
num_neighbors: NumNeighbors,
replace: bool = False,
directed: bool = True,
temporal_strategy: str = 'uniform',
input_type: Optional[Any] = None,
time_attr: Optional[str] = None,
is_sorted: bool = False,
Expand All @@ -47,6 +48,7 @@ def __init__(
self.num_neighbors = num_neighbors
self.replace = replace
self.directed = directed
self.temporal_strategy = temporal_strategy
self.node_time = None
self.input_type = input_type

Expand Down Expand Up @@ -237,6 +239,7 @@ def _sample(
self.replace,
self.directed,
disjoint,
self.temporal_strategy,
True, # return_edge_id
)
row, col, node, edge, batch = out + (None, )
Expand All @@ -259,6 +262,7 @@ def _sample(
self.directed,
)
else:
assert self.temporal_strategy == 'uniform'
fn = torch.ops.torch_sparse.hetero_temporal_neighbor_sample
out = fn(
self.node_types,
Expand Down Expand Up @@ -297,6 +301,7 @@ def _sample(
self.replace,
self.directed,
disjoint,
self.temporal_strategy,
True, # return_edge_id
)
row, col, node, edge, batch = out + (None, )
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,14 @@ def to_csc(

elif hasattr(data, 'adj_t'):
if src_node_time is not None:
raise NotImplementedError("Temporal sampling via 'SparseTensor' "
"format not yet supported")
# TODO (matthias) This only works when instantiating a
# `SparseTensor` with `is_sorted=True`. Otherwise, the
# `SparseTensor` will by default re-sort the neighbors according to
# column index.
# As such, we probably want to consider re-adding error:
# raise NotImplementedError("Temporal sampling via 'SparseTensor' "
# "format not yet supported")
pass
colptr, row, _ = data.adj_t.csr()

elif data.edge_index is not None:
Expand Down

0 comments on commit 00c3a5d

Please sign in to comment.