Commit

Merge pull request #19 from AIRI-Institute/misc
Misc
BerAnton committed Mar 19, 2024
2 parents 3ca2077 + b6f5b45 commit 89be191
Showing 18 changed files with 229 additions and 148 deletions.
2 changes: 1 addition & 1 deletion config/equiformer_v2_oc20.yaml
@@ -1,5 +1,5 @@
# Global variables
name: Equiformer_v2_OC20
name: Equiformer_v2
dataset_name: dataset_train_2k
max_steps: 1000000
warmup_steps: 0
26 changes: 26 additions & 0 deletions config/gemnet-oc_predict.yaml
@@ -0,0 +1,26 @@
# Global variables
name: GemNet-OC
dataset_name: null # filename w/o extension from root directory, see datamodule config
max_steps: 1000000
warmup_steps: 0
job_type: predict
pretrained: False
ckpt_path: null # path to checkpoint for training resume or test run

# configs
defaults:
- _self_
- datamodule: nablaDFT_pyg_test.yaml # dataset config
- model: gemnet-oc.yaml # model config
- callbacks: default.yaml # pl callbacks config
- loggers: wandb.yaml # pl loggers config
- trainer: test.yaml # trainer config

# need this to set working dir as current dir
hydra:
output_subdir: null
run:
dir: .
original_work_dir: ${hydra:runtime.cwd}

seed: 23
26 changes: 26 additions & 0 deletions config/gemnet-oc_test.yaml
@@ -0,0 +1,26 @@
# Global variables
name: GemNet-OC
dataset_name: dataset_test_conformations_2k
max_steps: 1000000
warmup_steps: 0
job_type: test
pretrained: False
ckpt_path: null # path to checkpoint for training resume or test run

# configs
defaults:
- _self_
- datamodule: nablaDFT_pyg_test.yaml # dataset config
- model: gemnet-oc.yaml # model config
- callbacks: default.yaml # pl callbacks config
- loggers: wandb.yaml # pl loggers config
- trainer: test.yaml # trainer config

# need this to set working dir as current dir
hydra:
output_subdir: null
run:
dir: .
original_work_dir: ${hydra:runtime.cwd}

seed: 23
2 changes: 1 addition & 1 deletion config/model/graphormer3d-half.yaml
@@ -20,7 +20,7 @@ optimizer:
lr: 3e-4

lr_scheduler:
_target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup
_target_: nablaDFT.schedulers.get_linear_schedule_with_warmup
_partial_: true
num_warmup_steps: ${warmup_steps}
num_training_steps: ${max_steps}
25 changes: 25 additions & 0 deletions config/schnet_test.yaml
@@ -0,0 +1,25 @@
# Global variables
name: SchNet
dataset_name: dataset_test_conformations_2k
max_steps: 1000000
job_type: test
pretrained: False
ckpt_path: null # path to checkpoint for training resume or test run

# configs
defaults:
- _self_
- datamodule: nablaDFT_ase_test.yaml # dataset config
- model: schnet.yaml # model config
- callbacks: callbacks_spk.yaml # pl callbacks config
- loggers: wandb.yaml # pl loggers config
- trainer: test.yaml # trainer config

# need this to set working dir as current dir
hydra:
output_subdir: null
run:
dir: .
original_work_dir: ${hydra:runtime.cwd}

seed: 23
3 changes: 2 additions & 1 deletion config/trainer/train.yaml
@@ -5,7 +5,8 @@ accelerator: "gpu"
devices: [0]
strategy:
_target_: pytorch_lightning.strategies.ddp.DDPStrategy

# QHNet has unused params, uncomment line for train
# find_unused_parameters: True
max_steps: ${max_steps}

# example of additional arguments for trainer
66 changes: 66 additions & 0 deletions nablaDFT/README.md
@@ -0,0 +1,66 @@
## Run configuration
The overall pipeline is heavily inspired by https://github.com/ashleve/lightning-hydra-template.
Run the command below from the repo root:
```bash
python run.py --config-name <config_name>.yaml
```
where `config_name` is one of the task YAML configs from the `config` directory.
The type of run is defined by the `job_type` parameter, which must be one of:
- train
- test
- predict
- optimize

Each config consists of global variables and sections for the other pipeline parts:
- datamodule
- model
- callbacks
- loggers
- trainer

### Datamodule

The datamodule config defines the dataset type (ASE, Hamiltonian, PyG), the dataset root path, the batch size, and the train/val ratio for training jobs.
Example configurations for the ASE and PyG dataset types are stored in `config/datamodule/`.
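A minimal sketch of what a PyG datamodule config might look like is shown below; the `_target_` class path and key names are illustrative assumptions, not the actual schema, so check `config/datamodule/` for the real keys.

```yaml
# Hypothetical sketch of a datamodule config; the class path and key names
# below are assumptions for illustration -- see config/datamodule/ for the
# actual files shipped with the repo.
_target_: nablaDFT.dataset.PyGNablaDFTDataModule  # assumed datamodule class
root: ./datasets/nablaDFT                         # dataset root path
dataset_name: ${dataset_name}                     # resolved from the global variable
train_size: 0.9                                   # train/val ratio for training jobs
batch_size: 32
num_workers: 4
```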

### Model

The model config defines the hyperparameters of the chosen model architecture, together with its metrics and losses. See the examples in `config/model/`.
To add another model, define a `LightningModule` (see examples in `nablaDFT/`) and point the run configuration at its model config.
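A minimal sketch of a model config follows, assuming the Hydra `_target_`/`_partial_` layout used by the existing configs; `MyLightningModel`, `hidden_dim`, and the `AdamW` optimizer choice are hypothetical, while the scheduler block mirrors `config/model/graphormer3d-half.yaml`.

```yaml
# Sketch only: MyLightningModel and its hidden_dim argument are hypothetical
# placeholders -- use config/model/*.yaml as templates for real models.
_target_: nablaDFT.my_model.MyLightningModel
hidden_dim: 128

optimizer:
  _target_: torch.optim.AdamW   # assumed optimizer choice
  _partial_: true
  lr: 3e-4

lr_scheduler:
  _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup
  _partial_: true
  num_warmup_steps: ${warmup_steps}
  num_training_steps: ${max_steps}
```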

### Callbacks

By default we use the `ModelCheckpoint` and `EarlyStopping` callbacks; additional callbacks can be added in `config/callbacks/default.yaml`.
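For example, registering an extra PyTorch Lightning callback might look like the sketch below; the `lr_monitor` key name is hypothetical and the surrounding layout of `default.yaml` is assumed.

```yaml
# Hypothetical addition to config/callbacks/default.yaml.
lr_monitor:
  _target_: pytorch_lightning.callbacks.LearningRateMonitor
  logging_interval: step
```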

### Loggers

By default we use only `WandbLogger`; you may add other loggers.

### Trainer

Defines additional trainer parameters, such as the accelerator, strategy, and devices.

## Train
Example: `config/gemnet-oc.yaml`
Usually you only need to change the `dataset_name` parameter to pick one of the nablaDFT training splits. The available training splits are listed in `nablaDFT/links/energy_databases.json`.
To resume training, specify the checkpoint path in the `ckpt_path` parameter.
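A sketch of the global variables for such a run, following the layout of the configs above; the checkpoint path is a hypothetical placeholder and `dataset_train_2k` is just one possible split.

```yaml
# Globals sketch: train on another split and resume from a checkpoint.
name: GemNet-OC
dataset_name: dataset_train_2k       # any split listed in nablaDFT/links/energy_databases.json
job_type: train
pretrained: False
ckpt_path: ./checkpoints/last.ckpt   # hypothetical path; use null to start from scratch
```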

## Test

Example: `config/gemnet-oc_test.yaml`
As with a training run, change the `dataset_name` parameter to the desired test split.
To reproduce the test results of the model pretrained on the largest training split (100k), set the `pretrained` parameter to `True` and `ckpt_path` to `null`. Otherwise, specify the path to a checkpoint with a pretrained model in `ckpt_path`.
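The relevant globals, following `config/gemnet-oc_test.yaml` shown above:

```yaml
# Reproduce published results with the checkpoint pretrained on the 100k split.
dataset_name: dataset_test_conformations_2k
job_type: test
pretrained: True   # use the checkpoint pretrained on the largest (100k) split
ckpt_path: null    # or set a local checkpoint path here and keep pretrained: False
```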

## Predict

Example: `config/gemnet-oc_predict.yaml`
To obtain model predictions for other datasets, use `job_type: predict`. Specify the dataset path relative to the `root` or `datapath` parameter of the datamodule config.
Predictions will be stored in `predictions/{model_name}_{dataset_name}.pt`.
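A hedged sketch of the globals for such a run; the dataset name and checkpoint path below are illustrative placeholders.

```yaml
# Hypothetical globals for a predict run on a custom database.
dataset_name: my_molecules   # filename w/o extension, resolved via the datamodule config
job_type: predict
pretrained: False
ckpt_path: ./checkpoints/gemnet-oc_100k.ckpt   # hypothetical checkpoint to predict with
```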

## Optimize

Examples: `config/gemnet-oc_optim.yaml`, `config/schnet_optim.yaml`
`job_type: optimize` runs molecular geometry optimization with a pretrained model.
Molecules from the database given by the `input_db` parameter are optimized using the pretrained checkpoint from `ckpt_path`.
Currently, only the LBFGS optimizer is supported.
Results are saved to the database given by the `output_db` parameter.
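A hedged sketch of the optimize-specific globals; the paths are placeholders and the key layout is assumed to mirror `config/gemnet-oc_optim.yaml`.

```yaml
# Hypothetical globals for a geometry-optimization run.
job_type: optimize
ckpt_path: ./checkpoints/gemnet-oc_100k.ckpt   # hypothetical pretrained checkpoint
input_db: ./data/molecules_to_optimize.db      # database with molecules to optimize (placeholder path)
output_db: ./data/optimized_molecules.db       # optimized geometries are written here
```
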
93 changes: 33 additions & 60 deletions nablaDFT/dataset/pyg_datasets.py
@@ -9,7 +9,7 @@
import numpy as np
import torch
from ase.db import connect
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.data import InMemoryDataset, Data, Dataset

import nablaDFT
from .hamiltonian_dataset import HamiltonianDatabase
@@ -70,7 +70,7 @@ def get(self, idx):
return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx])

def download(self) -> None:
with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f:
with open(nablaDFT.__path__[0] + "/links/energy_databases.json", "r") as f:
data = json.load(f)
url = data[f"{self.split}_databases"][self.dataset_name]
file_size = get_file_size(url)
@@ -84,10 +84,7 @@ def process(self) -> None:
z = torch.from_numpy(db_row.numbers).long()
positions = torch.from_numpy(db_row.positions).float()
y = torch.from_numpy(np.array(db_row.data["energy"])).float()
# TODO: temp workaround for dataset w/o forces
forces = db_row.data.get("forces", None)
if forces is not None:
forces = torch.from_numpy(np.array(forces)).float()
forces = torch.from_numpy(np.array(db_row.data["forces"])).float()
samples.append(Data(z=z, pos=positions, y=y, forces=forces))

if self.pre_filter is not None:
Expand All @@ -101,8 +98,7 @@ def process(self) -> None:
logger.info(f"Saved processed dataset: {self.processed_paths[0]}")


# TODO: move this to OnDiskDataset
class PyGHamiltonianNablaDFT(InMemoryDataset):
class PyGHamiltonianNablaDFT(Dataset):
"""Pytorch Geometric dataset for NablaDFT Hamiltonian database.
Args:
Expand All @@ -114,7 +110,7 @@ class PyGHamiltonianNablaDFT(InMemoryDataset):
- include_core (bool): if True, retrieves core Hamiltonian matrices from database.
- dtype (torch.dtype): defines torch.dtype for energy, positions, forces tensors.
- transform (Callable): callable data transform, called on every access to element.
- pre_transform (Callable): callable data transform, called during process() for every element.
- pre_transform (Callable): callable data transform, called on every access to element.
Note:
Hamiltonian matrix for each molecule has different shape. PyTorch Geometric tries to concatenate
each torch.Tensor in batch, so in order to make batch from data we leave all hamiltonian matrices
@@ -155,26 +151,37 @@ def __init__(
super(PyGHamiltonianNablaDFT, self).__init__(datapath, transform, pre_transform)

self.max_orbitals = self._get_max_orbitals(datapath, dataset_name)
for path in self.processed_paths:
data, slices = torch.load(path)
self.data_all.append(data)
self.slices_all.append(slices)
self.offsets.append(
len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1]
)
self.db = HamiltonianDatabase(self.raw_paths[0])

def len(self) -> int:
return sum(
len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all
)
return len(self.db)

def get(self, idx):
data_idx = 0
while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]:
data_idx += 1
self.data = self.data_all[data_idx]
self.slices = self.slices_all[data_idx]
return super(PyGHamiltonianNablaDFT, self).get(idx - self.offsets[data_idx])
data = self.db[idx]
z = torch.tensor(data[0]).long()
positions = torch.tensor(data[1]).to(self.dtype)
# see notes
hamiltonian = data[4]
if self.include_overlap:
overlap = data[5]
else:
overlap = None
if self.include_core:
core = data[6]
else:
core = None
y = torch.from_numpy(data[2]).to(self.dtype)
forces = torch.from_numpy(data[3]).to(self.dtype)
data = Data(
z=z, pos=positions,
y=y, forces=forces,
hamiltonian=hamiltonian,
overlap=overlap,
core=core,
)
if self.pre_transform is not None:
data = self.pre_transform(data)
return data

def download(self) -> None:
with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f:
@@ -185,41 +192,7 @@ def download(self) -> None:
request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t))

def process(self) -> None:
database = HamiltonianDatabase(self.raw_paths[0])
samples = []
for idx in tqdm(range(len(database)), total=len(database)):
data = database[idx]
z = torch.tensor(data[0]).long()
positions = torch.tensor(data[1]).to(self.dtype)
# see notes
hamiltonian = data[4]
if self.include_overlap:
overlap = data[5]
else:
overlap = None
if self.include_core:
core = data[6]
else:
core = None
y = torch.from_numpy(data[2]).to(self.dtype)
forces = torch.from_numpy(data[3]).to(self.dtype)
samples.append(Data(
z=z, pos=positions,
y=y, forces=forces,
hamiltonian=hamiltonian,
overlap=overlap,
core=core,
))

if self.pre_filter is not None:
samples = [data for data in samples if self.pre_filter(data)]

if self.pre_transform is not None:
samples = [self.pre_transform(data) for data in samples]

data, slices = self.collate(samples)
torch.save((data, slices), self.processed_paths[0])
logger.info(f"Saved processed dataset: {self.processed_paths[0]}")
pass

def _get_max_orbitals(self, datapath, dataset_name):
db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix)
8 changes: 1 addition & 7 deletions nablaDFT/dimenetplusplus/dimenetplusplus.py
@@ -144,13 +144,7 @@ def step(
) -> Union[Tuple[Any, Dict], Any]:
predictions_energy, predictions_forces = self.forward(batch)
loss_energy = self.loss(predictions_energy, batch.y)
# TODO: temp workaround
if hasattr(batch, "forces"):
loss_forces = self.loss(predictions_forces, batch.forces)
else:
loss_forces = torch.zeros(1).to(self.device)
predictions_forces = torch.zeros(1).to(self.device)
forces = torch.zeros(1).to(self.device)
loss_forces = self.loss(predictions_forces, batch.forces)
loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy
if calculate_metrics:
preds = {"energy": predictions_energy, "forces": predictions_forces}
6 changes: 1 addition & 5 deletions nablaDFT/equiformer_v2/equiformer_v2_oc20.py
@@ -698,11 +698,7 @@ def step(self, batch, calculate_metrics: bool = False):
y = batch.y
# make dense batch from PyG batch
energy_out, forces_out = self.net(batch)
# TODO: temp workaround
if hasattr(batch, "forces"):
forces = batch.forces
else:
forces = forces_out.clone()
forces = batch.forces
preds = {"energy": energy_out, "forces": forces_out}
target = {"energy": y, "forces": forces}
loss = self._calculate_loss(preds, target)
6 changes: 1 addition & 5 deletions nablaDFT/escn/escn.py
@@ -1074,11 +1074,7 @@ def step(self, batch, calculate_metrics: bool = False):
y = batch.y
# make dense batch from PyG batch
energy_out, forces_out = self.net(batch)
# TODO: temp workaround
if hasattr(batch, "forces"):
forces = batch.forces
else:
forces = forces_out.clone()
forces = batch.forces
preds = {"energy": energy_out, "forces": forces_out}
target = {"energy": y, "forces": forces}
loss = self._calculate_loss(preds, target)
5 changes: 1 addition & 4 deletions nablaDFT/gemnet_oc/gemnet_oc.py
@@ -1432,10 +1432,7 @@ def forward(self, data: Data):
def step(self, batch, calculate_metrics: bool = False):
energy_out, forces_out = self.net(batch)
# TODO: temp workaround
if hasattr(batch, "forces"):
forces = batch.forces
else:
forces = forces_out.clone()
forces = batch.forces
preds = {"energy": energy_out, "forces": forces_out}
target = {"energy": batch.y, "forces": forces}
loss = self._calculate_loss(preds, target)
16 changes: 5 additions & 11 deletions nablaDFT/graphormer/graphormer_3d.py
@@ -380,17 +380,11 @@ def step(
y = batch.y
energy_out, forces_out, mask_out = self(batch)
loss_energy = self.loss(energy_out, y)
# TODO: temp workaround for datasets w/o forces
if hasattr(batch, "forces"):
forces, mask_forces = to_dense_batch(
batch.forces, batch.batch, batch_size=bsz
)
masked_forces_out = forces_out * mask_forces.unsqueeze(-1)
loss_forces = self.loss(masked_forces_out, forces)
else:
loss_forces = torch.zeros(1).to(self.device)
masked_forces_out = torch.zeros(1).to(self.device)
forces = torch.zeros(1).to(self.device)
forces, mask_forces = to_dense_batch(
batch.forces, batch.batch, batch_size=bsz
)
masked_forces_out = forces_out * mask_forces.unsqueeze(-1)
loss_forces = self.loss(masked_forces_out, forces)
loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy
if calculate_metrics:
preds = {"energy": energy_out, "forces": masked_forces_out}