Skip to content

Commit

Permalink
Merge pull request #4 from NadineSchneider/master
Browse files Browse the repository at this point in the history
New code and examples related to a new book chapter
  • Loading branch information
NadineSchneider authored Apr 12, 2021
2 parents c64be56 + f920027 commit 67f6177
Show file tree
Hide file tree
Showing 15 changed files with 131,423 additions and 47 deletions.
54 changes: 38 additions & 16 deletions ChemTopicModel/chemTopicModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _generateFragmentVocabulary(self,molSample):
self.vocabulary = sorted(n for n,i in zip(keys,normFragOcc) if i != 0)
self.fragIdx=dict((i,j) for j,i in enumerate(self.vocabulary))
if self.verbose:
print('Created vocabulary, size: {0}, used sample size: {1}'.format(len(self.vocabulary),len(molSample)))
print('Created alphabet, size: {0}, used sample size: {1}'.format(len(self.alphabet),len(molSample)))

# generate the fragment templates important for the visualisation of the topics later
def _generateFragmentTemplates(self,molSample):
Expand Down Expand Up @@ -242,44 +242,65 @@ def generateFragments(self):
self._generateFragmentMatrix()

# it is better to use these functions instead of buildTopicModel if the dataset is larger
def fitTopicModel(self, numTopics, max_iter=100, **kwargs):
def fitTopicModel(self, numTopics, max_iter=100, nJobs=1, sizeFittingDataset=1.0, **kwargs):

self.lda = LatentDirichletAllocation(n_topics=numTopics,learning_method=self.learningMethod,random_state=self.seed,
n_jobs=1, max_iter=max_iter, batch_size=self.chunksize, **kwargs)
if self.fragM.shape[0] > self.chunksize:
self.lda = LatentDirichletAllocation(n_components=numTopics,learning_method=self.learningMethod,random_state=self.seed,
n_jobs=nJobs, max_iter=max_iter, batch_size=self.chunksize, **kwargs)

inputMatrix=self.fragM
if sizeFittingDataset < 1.0:

np.random.seed(self.seed)
upperIndex = self.fragM.shape[0]-1
size = int(self.fragM.shape[0]*sizeFittingDataset)
ids = np.random.randint(0,upperIndex, size=size)
inputMatrix = self.fragM[sorted(ids)]

if inputMatrix.shape[0] > self.chunksize:
# fit the model in chunks
self.lda.learning_method = 'online'
self.lda.fit(self.fragM)
else:
self.lda.fit(self.fragM)
self.lda.fit(inputMatrix)

def transformDataToTopicModel(self):
def transformDataToTopicModel(self,lowerPrecision=False):

try:
self.lda
except:
raise ValueError('No topic model is available')

if lowerPrecision:
print('WARNING: using lower precision mode')

if self.fragM.shape[0] > self.chunksize:
# after fitting transform the data to our model
for chunk in self._generateMatrixChunks(0,self.fragM.shape[0],chunksize=self.chunksize):
resultLDA = self.lda.transform(chunk[0])
# here using a 16bit float instead of the 64bit float would save memory and might be enough precision. Test that later!!
# here using a 32bit float instead of the 64bit float would save memory and might be enough precision. Test that later!!
if chunk[1] > 0:
self.documentTopicProbabilities = np.concatenate((self.documentTopicProbabilities,
if lowerPrecision:
self.documentTopicProbabilities = np.concatenate((self.documentTopicProbabilities,
(resultLDA/resultLDA.sum(axis=1,keepdims=1)).astype(np.float32)), axis=0)
else:
self.documentTopicProbabilities = np.concatenate((self.documentTopicProbabilities,
resultLDA/resultLDA.sum(axis=1,keepdims=1)), axis=0)
else:
self.documentTopicProbabilities = resultLDA/resultLDA.sum(axis=1,keepdims=1)
if lowerPrecision:
self.documentTopicProbabilities = self.documentTopicProbabilities.astype(np.float32)
else:
resultLDA = self.lda.transform(self.fragM)
self.documentTopicProbabilities = resultLDA/resultLDA.sum(axis=1,keepdims=1)
# next line is not needed anymore since it is normalized in sklearn already since version 0.18
# self.documentTopicProbabilities = resultLDA/resultLDA.sum(axis=1,keepdims=1)
self.documentTopicProbabilities = resultLDA
if lowerPrecision:
self.documentTopicProbabilities = self.documentTopicProbabilities.astype(np.float32)


# use this if the dataset is small- to medium-sized
def buildTopicModel(self, numTopics, max_iter=100, **kwargs):
def buildTopicModel(self, numTopics, max_iter=100, nJobs=1, lowerPrecision=False, sizeFittingDataset=0.1, **kwargs):

self.fitTopicModel(numTopics, max_iter=max_iter, **kwargs)
self.transformDataToTopicModel()
self.fitTopicModel(numTopics, max_iter=max_iter, nJobs=nJobs, sizeFittingDataset=sizeFittingDataset, **kwargs)
self.transformDataToTopicModel(lowerPrecision=lowerPrecision)


def getTopicFragmentProbabilities(self):
Expand All @@ -289,4 +310,5 @@ def getTopicFragmentProbabilities(self):
except:
raise ValueError('No topic model is available')
return self.lda.components_/self.lda.components_.sum(axis=1,keepdims=1)



6 changes: 3 additions & 3 deletions ChemTopicModel/drawFPBits.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import numpy as np
import re

def _drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,svg=True,**kwargs):
def _drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,svg=True, fontSize=0.9,**kwargs):
mol = Chem.MolFromSmiles(smi)
rdDepictor.Compute2DCoords(mol)

Expand Down Expand Up @@ -82,7 +82,7 @@ def _drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,svg=True
else:
drawer = rdMolDraw2D.MolDraw2DCairo(molSize[0],molSize[1])


drawer.SetFontSize(fontSize)
drawopt=drawer.drawOptions()
drawopt.continuousHighlight=False

Expand Down Expand Up @@ -118,7 +118,7 @@ def _drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,svg=True
def drawFPBitPNG(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,**kwargs):
return _drawFPBit(smi,bitPath,molSize=molSize,kekulize=kekulize,baseRad=baseRad, svg=False,**kwargs)

def drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,**kwargs):
def drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,fontSize=0.9,**kwargs):
svg = _drawFPBit(smi,bitPath,molSize=molSize,kekulize=kekulize,baseRad=baseRad,**kwargs)
return svg.replace('svg:','')

Expand Down
88 changes: 61 additions & 27 deletions ChemTopicModel/drawTopicModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from rdkit import Chem
from rdkit.Chem import rdqueries
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdDepictor

from IPython.display import display,HTML,SVG

from collections import defaultdict
Expand Down Expand Up @@ -97,7 +99,7 @@ def _getAtomWeights(mol, molID, topicID, topicModel):

# highlight a topic in a molecule
def drawTopicWeightsMolecule(mol, molID, topicID, topicModel, molSize=(450,200), kekulize=True,\
baseRad=0.1, color=(.9,.9,.9)):
baseRad=0.1, color=(.9,.9,.9), fontSize=0.9):

# get the atom weights
atomWeights,maxWeightTopic=_getAtomWeights(mol, molID, topicID, topicModel)
Expand All @@ -112,9 +114,14 @@ def drawTopicWeightsMolecule(mol, molID, topicID, topicModel, molSize=(450,200),

if atRads[at] > 0 and atRads[at] < 0.2:
atRads[at] = 0.2

try:
mc = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=kekulize)
except ValueError: # <- can happen on a kekulization failure
mc = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=False)

mc = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=kekulize)
drawer = rdMolDraw2D.MolDraw2DSVG(molSize[0],molSize[1])
drawer.SetFontSize(fontSize)
drawer.DrawMolecule(mc,highlightAtoms=atColors.keys(),
highlightAtomColors=atColors,highlightAtomRadii=atRads,
highlightBonds=[])
Expand All @@ -124,28 +131,34 @@ def drawTopicWeightsMolecule(mol, molID, topicID, topicModel, molSize=(450,200),

# generates all svgs of molecules belonging to a certain topic and highlights this topic within the molecule
def generateMoleculeSVGsbyTopicIdx(topicModel, topicIdx, idsLabelToShow=[0], topicProbThreshold = 0.5, baseRad=0.5,\
molSize=(250,150),color=(.0,.0, 1.),maxMols=100):
molSize=(250,150),color=(.0,.0, 1.),maxMols=100, fontSize=0.9, maxTopicProb=0.5):
svgs=[]
namesSVGs=[]
numDocs, numTopics = topicModel.documentTopicProbabilities.shape

if topicIdx >= numTopics:
return "Topic not found"
molset = topicModel.documentTopicProbabilities[:,topicIdx].argsort()[::-1][:maxMols]
for doc in molset:
if topicModel.documentTopicProbabilities[doc,topicIdx] >= topicProbThreshold:
data = topicModel.moldata.iloc[doc]
smi = data['smiles']
name = ''
for idx in idsLabelToShow:
name += str(data['label_'+str(idx)])
name += ' | '
mol = Chem.MolFromSmiles(smi)
topicProb = topicModel.documentTopicProbabilities[doc,topicIdx]
svg = drawTopicWeightsMolecule(mol, doc, topicIdx, topicModel, molSize=molSize, baseRad=baseRad, color=color)
svgs.append(svg)
maxTopicID= np.argmax(topicModel.documentTopicProbabilities[doc,:])
namesSVGs.append(str(name)+"(p="+str(round(topicProb,2))+")")
tmp=topicModel.documentTopicProbabilities[:,topicIdx]
ids=np.where(tmp >= topicProbThreshold)
molset = sorted(list(zip(tmp[ids].tolist(),ids[0].tolist())), reverse=True)[:maxMols]
if maxTopicProb > topicProbThreshold:
ids=np.where((tmp >= topicProbThreshold) & (tmp < maxTopicProb))
molset = sorted(list(zip(tmp[ids].tolist(),ids[0].tolist())), reverse=False)[:maxMols]

for prob,doc in molset:
data = topicModel.moldata.iloc[doc]
smi = data['smiles']
name = ''
for idx in idsLabelToShow:
name += str(data['label_'+str(idx)])
name += ' | '
mol = Chem.MolFromSmiles(smi)
topicProb = prob #topicModel.documentTopicProbabilities[doc,topicIdx]
svg = drawTopicWeightsMolecule(mol, doc, topicIdx, topicModel, molSize=molSize, baseRad=baseRad, color=color, fontSize=fontSize)
svgs.append(svg)
maxTopicID= np.argmax(topicModel.documentTopicProbabilities[doc])
maxProb = np.max(topicModel.documentTopicProbabilities[doc])
namesSVGs.append('{0}(p={1:.2f}) | (pmax({2})={3:.2f})'.format(name,topicProb,maxTopicID,maxProb))
if not len(svgs):
#print('No molecules can be drawn')
return [],[]
Expand Down Expand Up @@ -220,10 +233,10 @@ def generateSVGGridMolsByLabel(topicModel, label, idLabelToMatch=0, baseRad=0.5,

# draws molecules belonging to a certain topic in a html table and highlights this topic within the molecules
def drawMolsByTopic(topicModel, topicIdx, idsLabelToShow=[0], topicProbThreshold = 0.5, baseRad=0.5, molSize=(250,150),\
numRowsShown=3, color=(.0,.0, 1.), maxMols=100):
numRowsShown=3, color=(.0,.0, 1.), maxMols=100, fontSize=0.9,maxTopicProb=0.5):
result = generateMoleculeSVGsbyTopicIdx(topicModel, topicIdx, idsLabelToShow=idsLabelToShow, \
topicProbThreshold = topicProbThreshold, baseRad=baseRad,\
molSize=molSize,color=color, maxMols=maxMols)
molSize=molSize,color=color, maxMols=maxMols,fontSize=fontSize,maxTopicProb=maxTopicProb)
if len(result) == 1:
print(result)
return
Expand All @@ -242,11 +255,11 @@ def drawMolsByTopic(topicModel, topicIdx, idsLabelToShow=[0], topicProbThreshold

# produces a svg grid of the molecules belonging to a certain topic and highlights this topic within the molecules
def generateSVGGridMolsbyTopic(topicModel, topicIdx, idsLabelToShow=[0], topicProbThreshold = 0.5, baseRad=0.5, \
molSize=(250,150), svgsPerRow=4, color=(1.,1.,1.)):
molSize=(250,150), svgsPerRow=4, color=(1.,1.,1.), maxMols=100, fontSize=0.9, maxTopicProb=0.5):

result = generateMoleculeSVGsbyTopicIdx(topicModel, topicIdx, idsLabelToShow=idsLabelToShow, \
topicProbThreshold = topicProbThreshold, baseRad=baseRad,\
molSize=molSize,color=color)
molSize=molSize,color=color, maxMols=maxMols, fontSize=fontSize, maxTopicProb=maxTopicProb)
if len(result) == 1:
print(result)
return
Expand All @@ -259,9 +272,24 @@ def generateSVGGridMolsbyTopic(topicModel, topicIdx, idsLabelToShow=[0], topicPr

####### Fragments ##########################


# produces a svg grid of the fragments belonging to a certain topic
def generateSVGGridFragemntsForTopic(topicModel, topicIdx, n_top_frags=10, molSize=(100,100),\
svg=True, prior=-1.0, fontSize=0.9,svgsPerRow=4):

svgs = generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=n_top_frags, molSize=molSize,\
svg=svg, prior=prior, fontSize=fontSize)
scores = topicModel.getTopicFragmentProbabilities()
namesSVGs = list(map(lambda x: "p(k={0})={1:.2f}".format(topicIdx,x), \
filter(lambda y: y > prior, sorted(scores[topicIdx,:], reverse=True)[:n_top_frags])))

svgGrid = utilsDrawing.SvgsToGrid(svgs, namesSVGs, svgsPerRow=svgsPerRow, molSize=molSize)

return svgGrid

# generates svgs of the fragments related to a certain topic
def generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=10, molSize=(100,100),\
svg=True, prior=-1.0):
svg=True, prior=-1.0, fontSize=0.9):
svgs=[]
probs = topicModel.getTopicFragmentProbabilities()
numTopics, numFragments = probs.shape
Expand All @@ -271,12 +299,18 @@ def generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=10, molSi
for i in probs[topicIdx,:].argsort()[::-1][:n_top_frags]:
if probs[topicIdx,i] > prior:
bit = topicModel.vocabulary[i]

# allows including words
if type(bit) != int:
svgs.append(bit)
continue

# draw the bits using the templates
if topicModel.fragmentMethod in ['Morgan', 'RDK']:
templMol = topicModel.fragmentTemplates.loc[topicModel.fragmentTemplates['bitIdx'] == bit]['templateMol'].item()
pathTemplMol = topicModel.fragmentTemplates.loc[topicModel.fragmentTemplates['bitIdx'] == bit]['bitPathTemplateMol'].item()
if svg:
svgs.append(drawFPBits.drawFPBit(templMol,pathTemplMol,molSize=molSize))
svgs.append(drawFPBits.drawFPBit(templMol,pathTemplMol,molSize=molSize, fontSize=fontSize))
else:
svgs.append(drawFPBits.drawFPBitPNG(templMol,pathTemplMol,molSize=molSize))
else:
Expand All @@ -288,14 +322,14 @@ def generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=10, molSi

# draw the svgs of the fragments related to a certain topic in a html table
def drawFragmentsbyTopic(topicModel, topicIdx, n_top_frags=10, numRowsShown=4, cssTableName='fragTab', \
prior=-1.0, numColumns=4, tableHeader=''):
prior=-1.0, numColumns=4, tableHeader='',fontSize=0.9):

scores = topicModel.getTopicFragmentProbabilities()
numTopics, numFragments = scores.shape
if prior < 0:
prior = 1./numFragments
svgs=generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=n_top_frags, prior=prior)
namesSVGs = list(map(lambda x: "Score %.2f" % x, \
svgs=generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=n_top_frags, prior=prior,fontSize=fontSize)
namesSVGs = list(map(lambda x: "p(k={0})={1:.2f}".format(topicIdx,x), \
filter(lambda y: y > prior, sorted(scores[topicIdx,:], reverse=True)[:n_top_frags])))
if tableHeader == '':
tableHeader = "Topic "+str(topicIdx)
Expand Down
Loading

0 comments on commit 67f6177

Please sign in to comment.