Skip to content

Commit

Permalink
Merge pull request #26 from leynier/megexplainer-migration
Browse files Browse the repository at this point in the history
meg explainer migration
  • Loading branch information
MarioTheOne committed Jul 8, 2024
2 parents ccaaa90 + 2481a39 commit 8276cd7
Show file tree
Hide file tree
Showing 19 changed files with 1,103 additions and 621 deletions.
78 changes: 78 additions & 0 deletions config/leynier/meg-test-bbbp.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
{
"experiment": {
"scope": "meg_Test",
"parameters": {
"lock_release_tout": 120,
"propagate": []
}
},
"do-pairs": [
{
"dataset": {
"class": "src.dataset.dataset_base.Dataset",
"parameters": {
"generator": {
"class": "src.dataset.generators.bbbp.BBBP",
"parameters": {
"data_dir": "data/datasets/bbbp/",
"data_file_name": "BBBP-reduced.csv",
"data_label_name": "p_np"
}
},
"n_splits": 2
}
},
"oracle": {
"class": "src.oracle.nn.torch.OracleTorch",
"parameters": {
"epochs": 50,
"batch_size": 2,
"optimizer": {
"class": "torch.optim.RMSprop",
"parameters": {
"lr": 0.001,
"momentum": 0.5
}
},
"loss_fn": {
"class": "torch.nn.CrossEntropyLoss",
"parameters": {
"reduction": "mean"
}
},
"model": {
"class": "src.oracle.nn.gcn.DownstreamGCN",
"parameters": {
"num_conv_layers": 2,
"num_dense_layers": 1,
"conv_booster": 2,
"linear_decay": 1.8
}
}
}
}
}
],
"explainers": [
{
"class": "src.explainer.rl.meg.MEGExplainer",
"parameters": {
"num_input": 1024,
"env": {
"class": "src.explainer.rl.meg_utils.environments.bbbp_env.BBBPEnvironment",
"parameters": {}
},
"action_encoder": {
"class": "src.explainer.rl.meg_utils.utils.encoders.MorganCountFingerprintActionEncoder",
"parameters": {}
}
}
},
{
"class": "src.explainer.search.dces.DCESExplainer",
"parameters": {}
}
],
"compose_mes": "config/snippets/default_metrics.json",
"compose_strs": "config/snippets/default_store_paths.json"
}
5 changes: 5 additions & 0 deletions data/datasets/bbbp/BBBP-reduced.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
num,name,p_np,smiles
1,Propanolol,1,[Cl].CC(C)NCC(O)COc1cccc2ccccc12
2,Terbutylchlorambucil,1,C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
3,SKF-93619,0,c1cc2c(cc(CC3=CNC(=NC3=O)NCCSCc3oc(cc3)CN(C)C)cc2)cc1
4,M2L-663581,0,OCC(C)(O)c1onc(c2ncn3c2CN(C)C(c4c3cccc4Cl)=O)n1
14 changes: 14 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,21 @@
from src.utils.context import Context
import sys

try:
from dotenv import load_dotenv

load_dotenv()
except Exception:
pass

if __name__ == "__main__":
if len(sys.argv) < 2:
# If no arguments are passed, try to find GRETEL_CONFIG_FILE in the environment
if "GRETEL_CONFIG_FILE" in os.environ:
sys.argv.append(os.environ["GRETEL_CONFIG_FILE"])
else:
print("Usage: python main.py <config_file> [run_number]")
sys.exit(1)
print(f"Generating context for: {sys.argv[1]}")
context = Context.get_context(sys.argv[1])
context.run_number = int(sys.argv[2]) if len(sys.argv) == 3 else -1
Expand Down
10 changes: 6 additions & 4 deletions src/dataset/generators/mol_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import numpy as np
import pandas as pd
import networkx as nx
from rdkit import Chem
from rdkit.Chem import MolFromSmiles as smi2mol
from rdkit.Chem import MolToSmiles as mol2smi

from src.dataset.generators.base import Generator
from src.dataset.instances.graph import GraphInstance
from src.explainer.rl.meg_utils.utils.molecular_instance import (
MolecularInstance,
)


class MolGenerator(Generator):

Expand Down Expand Up @@ -65,7 +67,7 @@ def smile2graph(id, smile, label, dataset):
g = None
if sanitized:
A,X,W = mol_to_matrices(mol, dataset)
g = GraphInstance(id=id,
g = MolecularInstance(id=id,
label=int(label),
data=A,
node_features=X,
Expand Down Expand Up @@ -149,4 +151,4 @@ def rdk_enum_type_to_map(enum_type,offset=0):
def rdk_enum_val_to_one_hot(enum_type):
one_hot_vec = np.zeros((1,len(enum_type.values)))
one_hot_vec[0,enum_type.real]=1
return one_hot_vec
return one_hot_vec
17 changes: 9 additions & 8 deletions src/dataset/instances/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from src.dataset.dataset_base import Dataset


class DataInstance:

def __init__(self, id, label, data, dataset=None):
def __init__(self, id, label, data, dataset: Optional["Dataset"] = None):
self.id = id
self.data = data
self.label = label #TODO: Refactoring to have a one-hot encoding of labels!
self.label = label # TODO: Refactoring to have a one-hot encoding of labels!
self._dataset = dataset





16 changes: 14 additions & 2 deletions src/dataset/instances/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@ def __init__(self, id, label, data, node_features=None, edge_features=None, edge
self.graph_features = graph_features
self._nx_repr = None

num_nodes = self.data.shape[0]
num_edges = np.count_nonzero(self.data)
assert len(self.node_features) == num_nodes
assert len(self.edge_features) == num_edges
assert len(self.edge_weights) == num_edges

def __deepcopy__(self, memo):
num_nodes = self.data.shape[0]
num_edges = np.count_nonzero(self.data)
assert len(self.node_features) == num_nodes
assert len(self.edge_features) == num_edges
assert len(self.edge_weights) == num_edges

# Fields that are being shallow copied
_dataset = self._dataset

Expand All @@ -29,7 +41,7 @@ def __deepcopy__(self, memo):
_edge_features = deepcopy(self.edge_features, memo)
_edge_weights = deepcopy(self.edge_weights, memo)
_graph_features = deepcopy(self.graph_features, memo)
return GraphInstance(_new_id, _new_label, _data, _node_features, _edge_features, _edge_weights, _graph_features)
return GraphInstance(_new_id, _new_label, _data, _node_features, _edge_features, _edge_weights, _graph_features, _dataset)

def get_nx(self):
if not self._nx_repr:
Expand Down Expand Up @@ -73,4 +85,4 @@ def degree(self,node):
return len(self.neighbors(node))

def degrees(self):
return [ len(self.neighbors(y)) for y in self.nodes()]
return [ len(self.neighbors(y)) for y in self.nodes()]
Loading

0 comments on commit 8276cd7

Please sign in to comment.