Merge branch 'master' into test/loader-utils2
rusty1s committed Aug 10, 2023
2 parents 917c1ee + b965d76 commit 19c55e2
Showing 8 changed files with 296 additions and 115 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added TorchScript support inside `BasicGNN` models ([#7865](https://github.com/pyg-team/pytorch_geometric/pull/7865))
- Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))
- Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))
- Integrate `neg_sampling_ratio` into `TemporalDataLoader` ([#7644](https://github.com/pyg-team/pytorch_geometric/pull/7644))
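
For context, a minimal sketch of the `batch_size` argument mentioned in [#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851); the keyword name follows the changelog entry, and its use with `torch_geometric.utils.unbatch` is an assumption here:

```python
import torch
from torch_geometric.utils import unbatch

# Hypothetical usage of the new `batch_size` argument: passing the number of
# graphs explicitly instead of letting it be derived from the `batch` vector.
x = torch.randn(6, 8)                     # 6 nodes with 8 features each
batch = torch.tensor([0, 0, 1, 1, 1, 2])  # node-to-graph assignment
xs = unbatch(x, batch, batch_size=3)      # per-graph feature tensors
assert len(xs) == 3 and xs[0].size() == (2, 8)
```
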
10 changes: 10 additions & 0 deletions test/nn/models/test_basic_gnn.py
@@ -138,6 +138,16 @@ def test_edge_cnn(out_dim, dropout, act, norm, jk):
    assert model(x, edge_index).size() == (3, out_channels)


def test_jittable():
    x = torch.randn(3, 8)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    model = GCN(8, 16, num_layers=2).jittable()
    model = torch.jit.script(model)

    assert model(x, edge_index).size() == (3, 16)


@pytest.mark.parametrize('out_dim', out_dims)
@pytest.mark.parametrize('jk', jks)
def test_one_layer_gnn(out_dim, jk):
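
As a follow-up to `test_jittable` above, a scripted `BasicGNN` can also be serialized with standard TorchScript tooling; a minimal sketch, with an illustrative file name:

```python
import torch
from torch_geometric.nn import GCN

# Script the model as in `test_jittable`, then round-trip it through
# TorchScript serialization (the file name is illustrative).
model = torch.jit.script(GCN(8, 16, num_layers=2).jittable())
model.save('gcn_scripted.pt')
loaded = torch.jit.load('gcn_scripted.pt')

x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
assert loaded(x, edge_index).size() == (3, 16)
```
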
26 changes: 22 additions & 4 deletions test/nn/test_model_summary.py
@@ -64,6 +64,9 @@ def test_summary_basic(gcn):
| ├─(convs)ModuleList | -- | -- | 1,072 |
| │ └─(0)GCNConv | [100, 32], [2, 20] | [100, 16] | 528 |
| │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 |
| ├─(norms)ModuleList | -- | -- | -- |
| │ └─(0)Identity | [100, 16] | [100, 16] | -- |
| │ └─(1)Identity | -- | -- | -- |
+---------------------+--------------------+----------------+----------+
"""
    assert summary(gcn['model'], gcn['x'], gcn['edge_index']) == expected[1:-1]
@@ -81,6 +84,9 @@ def test_summary_with_sparse_tensor(gcn):
| ├─(convs)ModuleList | -- | -- | 1,072 |
| │ └─(0)GCNConv | [100, 32], [100, 100] | [100, 16] | 528 |
| │ └─(1)GCNConv | [100, 16], [100, 100] | [100, 32] | 544 |
| ├─(norms)ModuleList | -- | -- | -- |
| │ └─(0)Identity | [100, 16] | [100, 16] | -- |
| │ └─(1)Identity | -- | -- | -- |
+---------------------+-----------------------+----------------+----------+
"""
    assert summary(gcn['model'], gcn['x'], gcn['adj_t']) == expected[1:-1]
@@ -96,10 +102,15 @@ def test_summary_with_max_depth(gcn):
| ├─(dropout)Dropout | [100, 16] | [100, 16] | -- |
| ├─(act)ReLU | [100, 16] | [100, 16] | -- |
| ├─(convs)ModuleList | -- | -- | 1,072 |
| ├─(norms)ModuleList | -- | -- | -- |
+---------------------+--------------------+----------------+----------+
"""
-    assert summary(gcn['model'], gcn['x'], gcn['edge_index'],
-                   max_depth=1) == expected[1:-1]
+    assert summary(
+        gcn['model'],
+        gcn['x'],
+        gcn['edge_index'],
+        max_depth=1,
+    ) == expected[1:-1]


@withPackage('tabulate')
@@ -118,10 +129,17 @@ def test_summary_with_leaf_module(gcn):
| │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 |
| │ │ └─(aggr_module)SumAggregation | [120, 32], [120] | [100, 32] | -- |
| │ │ └─(lin)Linear | [100, 16] | [100, 32] | 512 |
| ├─(norms)ModuleList | -- | -- | -- |
| │ └─(0)Identity | [100, 16] | [100, 16] | -- |
| │ └─(1)Identity | -- | -- | -- |
+-----------------------------------------+--------------------+----------------+----------+
"""
-    assert summary(gcn['model'], gcn['x'], gcn['edge_index'],
-                   leaf_module=None) == expected[13:-1]
+    assert summary(
+        gcn['model'],
+        gcn['x'],
+        gcn['edge_index'],
+        leaf_module=None,
+    ) == expected[13:-1]


@withPackage('tabulate')
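
For context, the expected tables above (now including the `(norms)ModuleList` rows) come from `torch_geometric.nn.summary`; a minimal sketch of producing such a table, with illustrative model sizes rather than the exact `gcn` fixture used in the tests:

```python
import torch
from torch_geometric.nn import GCN, summary

# Illustrative model and inputs; the `gcn` fixture in the tests defines its own.
model = GCN(in_channels=32, hidden_channels=16, num_layers=2, out_channels=32)
x = torch.randn(100, 32)
edge_index = torch.randint(0, 100, (2, 20))

# `summary` relies on the optional `tabulate` package, hence `@withPackage`.
print(summary(model, x, edge_index, max_depth=1))
```
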
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
@@ -768,7 +768,8 @@ def register_edge_update_forward_hook(self,
    @torch.jit.unused
    def jittable(self, typing: Optional[str] = None) -> 'MessagePassing':
        r"""Analyzes the :class:`MessagePassing` instance and produces a new
-        jittable module.
+        jittable module that can be used in combination with
+        :meth:`torch.jit.script`.

        Args:
            typing (str, optional): If given, will generate a concrete instance
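
The updated docstring describes the usual two-step pattern for scripting a message-passing layer; a minimal sketch, assuming the standard `GCNConv` constructor:

```python
import torch
from torch_geometric.nn import GCNConv

# Produce a TorchScript-compatible clone of the layer, then script it.
conv = GCNConv(in_channels=16, out_channels=32).jittable()
conv = torch.jit.script(conv)

x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
out = conv(x, edge_index)  # shape: [10, 32]
```
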
(Diffs for the remaining 4 changed files are not shown here.)