Skip to content

Commit

Permalink
Merge pull request molecularsets#80 from spoilt333/combinatorial
Browse files Browse the repository at this point in the history
Improved combinatorial generator: ligate only valid pairs of radicals
  • Loading branch information
danpol committed May 13, 2020
2 parents b75dd89 + ff02c2a commit 8289bb1
Show file tree
Hide file tree
Showing 19 changed files with 198 additions and 70 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ before_install:
- export PATH="$HOME/conda/bin:$PATH"
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda create -n test_env python=$TRAVIS_PYTHON_VERSION pip cmake
- conda create -c rdkit -n test_env python=$TRAVIS_PYTHON_VERSION pip cmake rdkit
- source activate test_env

install:
- conda install -q -c rdkit rdkit
- conda install -q -c rdkit rdkit==2019.09.3.0
- conda install -q flake8 scipy=1.2.0 pylint
- python setup.py install

Expand Down
60 changes: 30 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,28 @@ Besides standard uniqueness and validity metrics, MOSES provides other metrics t
<td>0.9815±0.0012</td>
<td>0.5302±0.0163</td>
<td>0.0977±0.0142</td>
<td>0.8738±0.0002</td>
<td><b>0.8738±0.0002</b></td>
<td>0.8644±0.0002</td>
<td>0.9582±0.001</td>
<td>0.9694±0.001</td>
</tr>
<tr>
<td>Combinatorial</td>
<td>0.9979±0.0003</td>
<td>0.9983±0.0006</td>
<td>0.9948±0.0005</td>
<td>6.1626±0.0081</td>
<td>6.7734±0.0106</td>
<td>0.4226±0.0004</td>
<td>0.4079±0.0004</td>
<td>0.9151±0.0026</td>
<td>0.9099±0.0026</td>
<td>0.307±0.0187</td>
<td>0.0928±0.0079</td>
<td><b>0.8812±0.0003</b></td>
<td><b>0.8741±0.0003</b></td>
<td>0.7912±0.0021</td>
<td>0.9913±0.0004</td>
<td><b>1.0±0.0</b></td>
<td>0.9983±0.0015</td>
<td>0.9909±0.0009</td>
<td>4.2375±0.037</td>
<td>4.5113±0.0274</td>
<td>0.4514±0.0003</td>
<td>0.4388±0.0002</td>
<td>0.9912±0.0004</td>
<td>0.9904±0.0003</td>
<td>0.4445±0.0056</td>
<td>0.0865±0.0027</td>
<td>0.8732±0.0002</td>
<td><b>0.8666±0.0002</b></td>
<td>0.9557±0.0018</td>
<td>0.9878±0.0008</td>
</tr>
<tr>
<td>CharRNN</td>
Expand Down Expand Up @@ -212,21 +212,21 @@ Besides standard uniqueness and validity metrics, MOSES provides other metrics t
</tr>
<tr>
<td>LatentGAN</td>
<td>0.897±0.0024</td>
<td>0.8966±0.0029</td>
<td><b>1.0±0.0</b></td>
<td>0.997±0.0005</td>
<td>0.296±0.0214</td>
<td>0.8237±0.0295</td>
<td>0.5377±0.0013</td>
<td>0.5135±0.0009</td>
<td>0.9987±0.0003</td>
<td>0.9974±0.0003</td>
<td>0.8864±0.006</td>
<td>0.1004±0.0152</td>
<td>0.8565±0.0008</td>
<td>0.8504±0.0007</td>
<td>0.9727±0.001</td>
<td>0.9488±0.0014</td>
<td>0.9968±0.0002</td>
<td>0.2968±0.0087</td>
<td>0.8281±0.0117</td>
<td>0.5371±0.0004</td>
<td>0.5132±0.0002</td>
<td>0.9986±0.0004</td>
<td>0.9972±0.0007</td>
<td>0.8867±0.0009</td>
<td>0.1072±0.0098</td>
<td>0.8565±0.0007</td>
<td>0.8505±0.0006</td>
<td>0.9735±0.0006</td>
<td>0.9498±0.0006</td>
</tr>
</tbody>
</table>
Expand Down
4 changes: 2 additions & 2 deletions data/samples/combinatorial/combinatorial_1.csv
Git LFS file not shown
4 changes: 2 additions & 2 deletions data/samples/combinatorial/combinatorial_2.csv
Git LFS file not shown
4 changes: 2 additions & 2 deletions data/samples/combinatorial/combinatorial_3.csv
Git LFS file not shown
4 changes: 2 additions & 2 deletions data/samples/combinatorial/combinatorial_all.csv
Git LFS file not shown
2 changes: 1 addition & 1 deletion data/samples/combinatorial/metrics_combinatorial_1.csv
Git LFS file not shown
2 changes: 1 addition & 1 deletion data/samples/combinatorial/metrics_combinatorial_2.csv
Git LFS file not shown
2 changes: 1 addition & 1 deletion data/samples/combinatorial/metrics_combinatorial_3.csv
Git LFS file not shown
Binary file modified images/QED.pdf
Binary file not shown.
Binary file modified images/QED.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/SA.pdf
Binary file not shown.
Binary file modified images/SA.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/logP.pdf
Binary file not shown.
Binary file modified images/logP.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/weight.pdf
Binary file not shown.
Binary file modified images/weight.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
163 changes: 137 additions & 26 deletions moses/baselines/combinatorial.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,50 @@
from collections import Counter
import pickle
import re
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from rdkit import Chem
from rdkit.Chem.BRICS import reactionDefs

import moses
from moses.metrics.utils import fragmenter
from moses.utils import mapper

# Matches a BRICS attachment point such as ``[7*]`` and captures its numeric
# label; an optional ``a``/``b`` suffix after the digits is tolerated.
isotope_re = re.compile(r'\[([0-9]+)[ab]?\*\]')

"""
Prepare binary connection rules
"""
# binary_defs[label] is a bitmask with bit k set when ``label`` may be
# ligated to label k; connection_rules stores the same pairs in both orders.
binary_defs = {}
connection_rules = set()
for row in reactionDefs:
    for a, b, t in row:
        if a in ('7a', '7b'):
            # The 7a/7b pair is collapsed into label 7 bonding with itself.
            binary_defs[7] = 1 << 7
            connection_rules.add((7, 7))
            continue
        a, b = int(a), int(b)
        binary_defs[a] = binary_defs.get(a, 0) | (1 << b)
        binary_defs[b] = binary_defs.get(b, 0) | (1 << a)
        connection_rules.update(((a, b), (b, a)))


class CombinatorialGenerator:
def __init__(self, n_jobs=1):
def __init__(self, n_jobs=1, mode=0):
"""
Combinatorial Generator randomly connects BRICS fragments
Arguments:
n_jobs: number of processes for training
mode: sampling mode
0: Sample fragment then connection
1: Sample connection point then fragments
"""
self.n_jobs = n_jobs
self.set_mode(mode)
self.fitted = False

def fit(self, data):
Expand All @@ -43,6 +70,10 @@ def fit(self, data):
fragment.count('*')
for fragment in counts['fragment'].values
]
counts['connection_rules'] = [
self.get_connection_rule(fragment)
for fragment in counts['fragment'].values
]
counts['frequency'] = counts['count'] / counts['count'].sum()
self.fragment_counts = counts

Expand All @@ -67,7 +98,9 @@ def save(self, path):
" Fit the model first")
data = {
'fragment_counts': self.fragment_counts,
'fragments_count_distribution': self.fragments_count_distribution
'fragments_count_distribution': self.fragments_count_distribution,
'n_jobs': self.n_jobs,
'mode': self.mode
}
with open(path, 'wb') as f:
pickle.dump(data, f)
Expand All @@ -89,23 +122,30 @@ def load(cls, path):
model.fragment_counts = data['fragment_counts']
model.fragments_count_distribution = \
data['fragments_count_distribution']
model.n_jobs = data['n_jobs']
model.mode = data['mode']
model.fitted = True
return model

def set_mode(self, mode):
    """
    Validates and stores the sampling mode
    Arguments:
        mode: sampling mode
            0: Sample fragment then connection
            1: Sample connection point then fragments
    Raises:
        ValueError: if mode is neither 0 nor 1
    """
    if mode not in (0, 1):
        # Wrap in a 1-tuple: a bare tuple would be unpacked as the format
        # arguments and raise TypeError instead of the intended ValueError.
        raise ValueError('Incorrect mode value: %s' % (mode,))
    self.mode = mode

def generate_one(self, seed=None):
"""
Generates a SMILES string using fragment frequencies
Arguments:
seed: if specified, will set numpy seed before sampling
Retruns:
Returns:
SMILES string
"""
if seed is not None:
np.random.seed(seed)
if not self.fitted:
raise RuntimeError("Fit the model before generating")
if seed is not None:
np.random.seed(seed)
mol = None

# Sample the number of fragments
Expand All @@ -114,17 +154,11 @@ def generate_one(self, seed=None):
total_fragments = np.random.choice(count_values, p=count_probs)

counts = self.fragment_counts
current_attachments = 0
max_attachments = total_fragments - 1
connections_mol = None
for i in range(total_fragments):
# Enforce lower and upper limit on the number of connection points
if mol is None:
current_attachments = 0
max_attachments = total_fragments - 1
else:
connections_mol = self.get_connection_points(mol)
current_attachments = len(connections_mol)
max_attachments = (
total_fragments - i - current_attachments + 1
)
counts_masked = counts[
counts['attachment_points'] <= max_attachments
]
Expand All @@ -139,32 +173,109 @@ def generate_one(self, seed=None):
counts_masked['attachment_points'] >= min_attachments
]

# Sample a new fragment
new_fragment = counts_masked.sample(
weights=counts_masked['frequency']
)
new_fragment = dict(new_fragment.iloc[0])
fragment = Chem.MolFromSmiles(new_fragment['fragment'])
if mol is None:
mol = fragment
mol = self.sample_fragment(counts_masked)
else:
# Connect a new fragment to the molecule
connection_mol = np.random.choice(connections_mol)
if self.mode == 1: # Choose connection atom first
connections_mol = [np.random.choice(connections_mol)]

con_filter = self.get_connection_filter(connections_mol)
# Mask fragments with possible reactions
counts_masked = counts_masked[
counts_masked['connection_rules'] & con_filter > 0
]
fragment = self.sample_fragment(counts_masked)
connections_fragment = self.get_connection_points(fragment)
connection_fragment = np.random.choice(connections_fragment)
mol = self.connect_mols(mol, fragment,
connection_mol, connection_fragment)
possible_connections = self.filter_connections(
connections_mol,
connections_fragment
)

c_i = np.random.choice(len(possible_connections))
a1, a2 = possible_connections[c_i]

# Connect a new fragment to the molecule
mol = self.connect_mols(mol, fragment, a1, a2)

connections_mol = self.get_connection_points(mol)
current_attachments = len(connections_mol)
max_attachments = (
total_fragments - i - current_attachments
)
smiles = Chem.MolToSmiles(mol)
return smiles

def generate(self, n, seed=1, mode=0, verbose=False):
    """
    Generates n samples by calling generate_one once per seed
    Arguments:
        n: number of samples to generate
        seed: batch index; the seeds used are
            (seed - 1) * n, ..., seed * n - 1
        mode: sampling mode, passed to set_mode; it replaces the mode
            currently stored on the generator, including the default 0
        verbose: if True, prints a message and wraps the seeds in tqdm
    Returns:
        whatever mapper(self.n_jobs) returns for generate_one over the seeds
    """
    self.set_mode(mode)
    first_seed = (seed - 1) * n
    seed_range = range(first_seed, seed * n)
    if verbose:
        print('generating...')
        seed_range = tqdm(seed_range, total=n)
    run = mapper(self.n_jobs)
    return run(self.generate_one, seed_range)

def get_connection_rule(self, fragment):
    """
    Builds the bitmask of labels that can react with this fragment:
    the OR of binary_defs over every distinct attachment label found
    in the SMILES string
    Arguments:
        fragment: fragment smiles
    Returns:
        int (0 when the fragment has no attachment labels)
    """
    labels = {int(label) for label in isotope_re.findall(fragment)}
    mask = 0
    for label in labels:
        mask |= binary_defs[label]
    return mask

@staticmethod
def sample_fragment(counts):
    """
    Draws one row from the fragment table, weighted by its
    'frequency' column, and parses its 'fragment' SMILES
    Arguments:
        counts: table with 'fragment' and 'frequency' columns
    Returns:
        result of Chem.MolFromSmiles for the sampled fragment
    """
    drawn = counts.sample(weights=counts['frequency'])
    record = dict(drawn.iloc[0])
    return Chem.MolFromSmiles(record['fragment'])

@staticmethod
def get_connection_filter(atoms):
    """
    Builds a bitmask with bit k set for every isotope label k
    present among the given atoms
    Arguments:
        atoms: iterable of objects exposing GetIsotope()
    Returns:
        int (0 for an empty iterable)
    """
    mask = 0
    for isotope in (atom.GetIsotope() for atom in atoms):
        mask |= 1 << isotope
    return mask

@staticmethod
def get_connection_points(mol):
    """
    Collects the atoms of a molecule whose symbol is '*'
    Arguments:
        mol: ROMol
    Returns:
        atom list, in the order yielded by mol.GetAtoms()
    """
    return [atom for atom in mol.GetAtoms() if atom.GetSymbol() == '*']

@staticmethod
def filter_connections(atoms1, atoms2, unique=True):
    """
    Lists every pair (a1, a2), a1 from atoms1 and a2 from atoms2, whose
    isotope labels form an allowed pair in connection_rules
    Arguments:
        atoms1: connection atoms of the first molecule
        atoms2: connection atoms of the second molecule
        unique: not used by this function
    Returns:
        list of (atom, atom) tuples, ordered by atoms1 then atoms2
    """
    return [
        (first, second)
        for first in atoms1
        for second in atoms2
        if (first.GetIsotope(), second.GetIsotope()) in connection_rules
    ]

@staticmethod
def connect_mols(mol1, mol2, atom1, atom2):
combined = Chem.CombineMols(mol1, mol2)
Expand Down
Loading

0 comments on commit 8289bb1

Please sign in to comment.