Skip to content

Commit

Permalink
Fix node_attributes shape in read_tu_data (#5441)
Browse files Browse the repository at this point in the history
* Fix node_attributes shape

* Update tu.py

* Update CHANGELOG.md

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
flandolfi and rusty1s committed Sep 14, 2022
1 parent 83d0f32 commit b3c0318
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Fixed a bug in `TUDataset` in which node features were wrongly constructed whenever `node_attributes` only hold a single feature (*e.g.*, in `PROTEINS`) ([#5441](https://github.com/pyg-team/pytorch_geometric/pull/5411))
- Breaking change: removed `num_neighbors` as an attribute of loader ([#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404))
- `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395))
- Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317))
Expand Down
4 changes: 4 additions & 0 deletions torch_geometric/io/tu.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def read_tu_data(folder, prefix):
node_attributes = torch.empty((batch.size(0), 0))
if 'node_attributes' in names:
node_attributes = read_file(folder, prefix, 'node_attributes')
if node_attributes.dim() == 1:
node_attributes = node_attributes.unsqueeze(-1)

node_labels = torch.empty((batch.size(0), 0))
if 'node_labels' in names:
Expand All @@ -41,6 +43,8 @@ def read_tu_data(folder, prefix):
edge_attributes = torch.empty((edge_index.size(1), 0))
if 'edge_attributes' in names:
edge_attributes = read_file(folder, prefix, 'edge_attributes')
if edge_attributes.dim() == 1:
edge_attributes = edge_attributes.unsqueeze(-1)

edge_labels = torch.empty((edge_index.size(1), 0))
if 'edge_labels' in names:
Expand Down

0 comments on commit b3c0318

Please sign in to comment.