Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DimeNet++ implementation #4432

Merged
merged 14 commits into from
May 24, 2022
Merged

DimeNet++ implementation #4432

merged 14 commits into from
May 24, 2022

Conversation

arunppsg
Copy link
Contributor

@arunppsg arunppsg commented Apr 7, 2022

In this PR, I implemented DimeNet++ by subclassing DimeNet. The InteractionPPBlock and OutputPPBlock was based on this implementation.

Also, I depreciated OutputBlock, thereby DimeNet model will also be using OutputPPBlock. OutputPPBlock allows up-projection and down-projection of embeddings, thereby removing information bottlenecks.

Fixes #4427

@codecov
Copy link

codecov bot commented Apr 7, 2022

Codecov Report

Merging #4432 (dd4790f) into master (5a6e826) will decrease coverage by 0.64%.
The diff coverage is 10.38%.

@@            Coverage Diff             @@
##           master    #4432      +/-   ##
==========================================
- Coverage   82.98%   82.33%   -0.65%     
==========================================
  Files         319      319              
  Lines       16968    17117     +149     
==========================================
+ Hits        14081    14094      +13     
- Misses       2887     3023     +136     
Impacted Files Coverage Δ
torch_geometric/nn/models/dimenet.py 14.58% <9.80%> (-3.09%) ⬇️
torch_geometric/nn/models/__init__.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 5a6e826...dd4790f. Read the comment docs.

Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Can we add a basic test for DimeNetPlusPlus? Also I am interested in whether it is possible to make use of pre-trained weights for this model as well (similar to what we do for DimeNet), see here.

@arunppsg
Copy link
Contributor Author

arunppsg commented Apr 8, 2022

Sure, I will look into using pre-trained weights and also add tests.

@arunppsg arunppsg marked this pull request as draft April 9, 2022 07:42
@arunppsg arunppsg force-pushed the dimenet_pp branch 2 times, most recently from 02d5798 to 79a6347 Compare April 12, 2022 05:56
@arunppsg
Copy link
Contributor Author

@rusty1s regarding tests for DimeNetPlusPlus, I have added a test for checking output size and another one test for checking whether the model is training well or not - test_overfit. The downside of test_overfit is that it is sometimes flaky but otherwise it is a good test for checking whether the model is able to train or not.

@arunppsg arunppsg marked this pull request as ready for review April 12, 2022 07:04
@arunppsg arunppsg force-pushed the dimenet_pp branch 2 times, most recently from 8b79778 to 91c63ec Compare April 12, 2022 15:08
@arunppsg arunppsg requested a review from rusty1s April 13, 2022 06:57
@rusty1s
Copy link
Member

rusty1s commented Apr 13, 2022

Thank you @arunppsg! I am currently on vacation, will have a final look in the next week.

@arunppsg
Copy link
Contributor Author

arunppsg commented May 1, 2022

The failures in testing are due to not having sympy in requirements. It is used by DimeNet++ and DimeNet for basis computation.

@rusty1s
Copy link
Member

rusty1s commented May 2, 2022

@arunppsg Can we add the necessary dependencies to full_install_requires, see here?

@arunppsg
Copy link
Contributor Author

@rusty1s added the necessary requirements.

@rusty1s rusty1s changed the title Dimenet++ implementation DimeNet++ implementation May 23, 2022
@rusty1s
Copy link
Member

rusty1s commented May 23, 2022

Thanks @arunppsg. I think I fixed the dependency issue during testing. I also added a --use_dimenet_plus_plus option to the QM9 script. However, the results of the pre-trained model are far off, e.g.:

Target: 00, MAE: 4.85522 ± 3.42503
Target: 01, MAE: 15.25113 ± 6.21239
Target: 02, MAE: 603.06146 ± 484.24606
Target: 03, MAE: 994.61578 ± 703.35645
Target: 05, MAE: 977.33118 ± 332.41879
Target: 06, MAE: 594.56433 ± 317.66690
Target: 07, MAE: 3095.07080 ± 1889.53760

Do you know what's the reason for this?

@arunppsg
Copy link
Contributor Author

arunppsg commented May 24, 2022

Thanks @arunppsg. I think I fixed the dependency issue during testing. I also added a --use_dimenet_plus_plus option to the QM9 script. However, the results of the pre-trained model are far off, e.g.:

Target: 00, MAE: 4.85522 ± 3.42503
Target: 01, MAE: 15.25113 ± 6.21239
Target: 02, MAE: 603.06146 ± 484.24606
Target: 03, MAE: 994.61578 ± 703.35645
Target: 05, MAE: 977.33118 ± 332.41879
Target: 06, MAE: 594.56433 ± 317.66690
Target: 07, MAE: 3095.07080 ± 1889.53760

Do you know what's the reason for this?

Not sure about it, the results are far off from the dimenet++ paper. Can we create an issue and track it separately?

@rusty1s
Copy link
Member

rusty1s commented May 24, 2022

This sounds good to me. I will try to look into this as well.

@rusty1s rusty1s merged commit a7e6be4 into pyg-team:master May 24, 2022
@arunppsg
Copy link
Contributor Author

Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for DimeNet++ model
3 participants