Skip to content

Commit

Permalink
combinatorial generator updated to do only valid fragment combination
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur Kadurin committed Apr 23, 2020
1 parent b75dd89 commit 78cf286
Showing 1 changed file with 148 additions and 26 deletions.
174 changes: 148 additions & 26 deletions moses/baselines/combinatorial.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,54 @@
from collections import Counter
import pickle
import re
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from rdkit import Chem
from rdkit.Chem.BRICS import reactionDefs

import moses
from moses.metrics.utils import fragmenter
from moses.utils import mapper

# Matches a labelled dummy atom such as ``[7*]`` or ``[7a*]`` and captures
# the numeric part of the label.
isotope_re = re.compile(r'\[([0-9]+)[ab]?\*\]')

# Prepare binary connection rules.
# binary_defs[x] is a bitmask with bit y set for every label y that is
# paired with x in reactionDefs; connection_rules holds the same pairs as
# (x, y) tuples in both orders.
binary_defs = {}
connection_rules = set()
for reaction_group in reactionDefs:
    for first, second, _ in reaction_group:
        if first in ('7a', '7b'):
            # The 7a/7b labels are collapsed into the single label 7
            binary_defs[7] = 1 << 7
            connection_rules.add((7, 7))
            continue
        first, second = int(first), int(second)
        binary_defs[first] = binary_defs.get(first, 0) | (1 << second)
        binary_defs[second] = binary_defs.get(second, 0) | (1 << first)
        connection_rules.update({(first, second), (second, first)})


class CombinatorialGenerator:
def __init__(self, n_jobs=1):
def __init__(self, n_jobs=1, mode=0):
"""
Combinatorial Generator randomly connects BRICS fragments
Arguments:
n_jobs: number of processes for training
mode: sampling mode
last bit sets sampling connection point or fragment first
second bit sets sampling connection between two fragments
0: Sample connection point, sample from unique reactions
1: Sample fragment first, sample from unique reactions
2: Sample connection point, sample from all possible reactions
3: Sample fragment first, sample from all possible reactions
"""
self.n_jobs = n_jobs
self.mode = mode
self.fitted = False

def fit(self, data):
Expand All @@ -43,6 +74,10 @@ def fit(self, data):
fragment.count('*')
for fragment in counts['fragment'].values
]
counts['connection_rules'] = [
self.get_connection_rule(fragment)
for fragment in counts['fragment'].values
]
counts['frequency'] = counts['count'] / counts['count'].sum()
self.fragment_counts = counts

Expand All @@ -67,7 +102,9 @@ def save(self, path):
" Fit the model first")
data = {
'fragment_counts': self.fragment_counts,
'fragments_count_distribution': self.fragments_count_distribution
'fragments_count_distribution': self.fragments_count_distribution,
'n_jobs': self.n_jobs,
'mode': self.mode
}
with open(path, 'wb') as f:
pickle.dump(data, f)
Expand All @@ -89,23 +126,30 @@ def load(cls, path):
model.fragment_counts = data['fragment_counts']
model.fragments_count_distribution = \
data['fragments_count_distribution']
model.n_jobs = data['n_jobs']
model.mode = data['mode']
model.fitted = True
return model

def set_mode(self, mode):
if mode not in [0, 1, 2, 3]:
raise ValueError('Incorrect mode value: %s' % mode)
self.mode = mode

def generate_one(self, seed=None):
"""
Generates a SMILES string using fragment frequencies
Arguments:
seed: if specified, will set numpy seed before sampling
Retruns:
Returns:
SMILES string
"""
if seed is not None:
np.random.seed(seed)
if not self.fitted:
raise RuntimeError("Fit the model before generating")
if seed is not None:
np.random.seed(seed)
mol = None

# Sample the number of fragments
Expand All @@ -114,17 +158,11 @@ def generate_one(self, seed=None):
total_fragments = np.random.choice(count_values, p=count_probs)

counts = self.fragment_counts
current_attachments = 0
max_attachments = total_fragments - 1
connections_mol = None
for i in range(total_fragments):
# Enforce lower and upper limit on the number of connection points
if mol is None:
current_attachments = 0
max_attachments = total_fragments - 1
else:
connections_mol = self.get_connection_points(mol)
current_attachments = len(connections_mol)
max_attachments = (
total_fragments - i - current_attachments + 1
)
counts_masked = counts[
counts['attachment_points'] <= max_attachments
]
Expand All @@ -139,32 +177,116 @@ def generate_one(self, seed=None):
counts_masked['attachment_points'] >= min_attachments
]

# Sample a new fragment
new_fragment = counts_masked.sample(
weights=counts_masked['frequency']
)
new_fragment = dict(new_fragment.iloc[0])
fragment = Chem.MolFromSmiles(new_fragment['fragment'])
if mol is None:
mol = fragment
mol = self.sample_fragment(counts_masked)
else:
# Connect a new fragment to the molecule
connection_mol = np.random.choice(connections_mol)
if self.mode & 1: # Sample fragment first
con_filter = self.get_connection_filter(connections_mol)
else: # Choose connection atom first
atom_mol = np.random.choice(connections_mol)
connections_mol = [atom_mol]
con_filter = 2**atom_mol.GetIsotope()

# Mask fragments with possible reactions
counts_masked = counts_masked[
counts_masked['connection_rules'] & con_filter > 0
]
fragment = self.sample_fragment(counts_masked)
connections_fragment = self.get_connection_points(fragment)
connection_fragment = np.random.choice(connections_fragment)
mol = self.connect_mols(mol, fragment,
connection_mol, connection_fragment)
possible_connections = self.filter_connections(
connections_mol,
connections_fragment
)

if self.mode & 2: # Sample weighted connection
c_i = np.random.choice(len(possible_connections))
a1, a2 = possible_connections[c_i]
else: # Sample from unique connections
possible_connections = list(set(possible_connections))
c_i = np.random.choice(len(possible_connections))
a1, a2 = possible_connections[c_i]

# Connect a new fragment to the molecule
mol = self.connect_mols(mol, fragment, a1, a2)

connections_mol = self.get_connection_points(mol)
current_attachments = len(connections_mol)
max_attachments = (
total_fragments - i - current_attachments
)
smiles = Chem.MolToSmiles(mol)
return smiles

def generate(self, n, seed, mode=0):
    """
    Generates n samples by calling generate_one repeatedly

    Arguments:
        n: number of samples to generate
        seed: base seed; sample i is generated with seed + i.
            None is passed through to generate_one unchanged
        mode: sampling mode, applied with set_mode before sampling
    Returns:
        list with the n values returned by generate_one
    """
    self.set_mode(mode)
    # Passing the very same seed to every call would make generate_one
    # reseed numpy identically each time and return n identical samples,
    # so each sample gets its own offset from the base seed.
    seeds = (None if seed is None else seed + i for i in range(n))
    generator = (self.generate_one(sample_seed) for sample_seed in seeds)
    # verbose is not guaranteed to exist on the instance; treat a missing
    # attribute as "quiet" instead of failing with AttributeError.
    if getattr(self, 'verbose', False):
        print('generating...')
        generator = tqdm(generator, total=n)
    return list(generator)

def get_connection_rule(self, fragment):
    """
    Combine, with bitwise OR, the binary_defs masks of every distinct
    label that isotope_re finds in the fragment

    Arguments:
        fragment: fragment smiles
    Returns:
        int bitmask (0 when no label is found)
    """
    labels = {int(label) for label in isotope_re.findall(fragment)}
    mask = 0
    for label in labels:
        mask |= binary_defs[label]
    return mask

@staticmethod
def sample_fragment(counts):
    """
    Draw one row at random, weighted by the 'frequency' column, and
    parse its 'fragment' value

    Arguments:
        counts: table with 'fragment' and 'frequency' columns
    Returns:
        result of Chem.MolFromSmiles for the sampled fragment
    """
    sampled = counts.sample(weights=counts['frequency'])
    smiles = sampled.iloc[0]['fragment']
    return Chem.MolFromSmiles(smiles)

@staticmethod
def get_connection_filter(atoms):
"""
Return OR(2**isotopes)
"""
connection_rule = 0
for atom in atoms:
connection_rule |= 2**atom.GetIsotope()
return connection_rule

@staticmethod
def get_connection_points(mol):
"""
Return connection points
Arguments:
mol: ROMol
Returns:
atom list
"""
atoms = []
for atom in mol.GetAtoms():
if atom.GetSymbol() == '*':
atoms.append(atom)
return atoms

@staticmethod
def filter_connections(atoms1, atoms2):
    """
    List the atom pairs whose isotope pair is in connection_rules

    Arguments:
        atoms1: atoms of the first molecule
        atoms2: atoms of the second molecule
    Returns:
        list of (atom from atoms1, atom from atoms2) tuples, ordered by
        atoms1 first and atoms2 second
    """
    return [
        (left, right)
        for left in atoms1
        for right in atoms2
        if (left.GetIsotope(), right.GetIsotope()) in connection_rules
    ]

@staticmethod
def connect_mols(mol1, mol2, atom1, atom2):
combined = Chem.CombineMols(mol1, mol2)
Expand Down

0 comments on commit 78cf286

Please sign in to comment.