Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DimeNet++ implementation #4432

Merged
merged 14 commits into from
May 24, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671))
- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676))
Expand Down
18 changes: 13 additions & 5 deletions examples/qm9_pretrained_dimenet.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
import argparse
import os.path as osp

import torch

from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DimeNet
from torch_geometric.nn import DimeNet, DimeNetPlusPlus

parser = argparse.ArgumentParser()
parser.add_argument('--use_dimenet_plus_plus', action='store_true')
args = parser.parse_args()

Model = DimeNetPlusPlus if args.use_dimenet_plus_plus else DimeNet

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')
dataset = QM9(path)

# DimeNet uses the atomization energy for targets U0, U, H, and G.
# DimeNet uses the atomization energy for targets U0, U, H, and G, i.e.:
# 7 -> 12, 8 -> 13, 9 -> 14, 10 -> 15
idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 12, 13, 14, 15, 11])
dataset.data.y = dataset.data.y[:, idx]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for target in range(12):
# Skip target \delta\epsilon, since it can be computed via
# \epsilon_{LUMO} - \epsilon_{HOMO}.
# \epsilon_{LUMO} - \epsilon_{HOMO}:
if target == 4:
continue

model, datasets = DimeNet.from_qm9_pretrained(path, dataset, target)
model, datasets = Model.from_qm9_pretrained(path, dataset, target)
train_dataset, val_dataset, test_dataset = datasets

model = model.to(device)
Expand All @@ -37,7 +45,7 @@

mae = torch.cat(maes, dim=0)

# Report meV instead of eV.
# Report meV instead of eV:
mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae

print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}')
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
full_requires = graphgym_requires + [
'h5py',
'numba',
'sympy',
'pandas',
'captum',
'rdflib',
Expand Down
41 changes: 41 additions & 0 deletions test/nn/models/test_dimenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import DimeNetPlusPlus
from torch_geometric.testing import onlyFullTest


@onlyFullTest
def test_dimenet_plus_plus():
data = Data(
z=torch.randint(1, 10, (20, )),
pos=torch.randn(20, 3),
y=torch.tensor([1.]),
)

model = DimeNetPlusPlus(
hidden_channels=5,
out_channels=1,
num_blocks=5,
out_emb_channels=3,
int_emb_size=5,
basis_emb_size=5,
num_spherical=5,
num_radial=5,
num_before_skip=2,
num_after_skip=2,
)

with torch.no_grad():
out = model(data.z, data.pos)
assert out.size() == (1, )

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
for i in range(100):
optimizer.zero_grad()
out = model(data.z, data.pos)
loss = F.l1_loss(out, data.y)
loss.backward()
optimizer.step()
assert loss < 1
3 changes: 2 additions & 1 deletion torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .re_net import RENet
from .graph_unet import GraphUNet
from .schnet import SchNet
from .dimenet import DimeNet
from .dimenet import DimeNet, DimeNetPlusPlus
from .explainer import Explainer, to_captum
from .gnn_explainer import GNNExplainer
from .metapath2vec import MetaPath2Vec
Expand Down Expand Up @@ -43,6 +43,7 @@
'GraphUNet',
'SchNet',
'DimeNet',
'DimeNetPlusPlus',
'Explainer',
'to_captum',
'GNNExplainer',
Expand Down
Loading