Skip to content

Commit

Permalink
Merge pull request #198 from datamol-io/lasso_fixes
Browse files Browse the repository at this point in the history
Fix bugs and support atom indices as inputs in the lasso viz function
  • Loading branch information
hadim authored Jun 16, 2023
2 parents 55c190c + c6c656b commit f5a7584
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 64 deletions.
6 changes: 3 additions & 3 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

_Checklist:_

- [x] _Was this PR discussed in an issue? It is recommended to first discuss a new feature into a GitHub issue before opening a PR._
- [ ] _Was this PR discussed in an issue? It is recommended to first discuss a new feature into a GitHub issue before opening a PR._
- [ ] _Add tests to cover the fixed bug(s) or the new introduced feature(s) (if appropriate)._
- [ ] _Update the API documentation is a new function is added, or an existing one is deleted._
- [x] _Write concise and explanatory changelogs below._
- [x] _If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you)._
- [ ] _Write concise and explanatory changelogs below._
- [ ] _If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you)._

---

Expand Down
96 changes: 50 additions & 46 deletions datamol/viz/_lasso_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# - possibility to do this for multiple target molecules at once
# - have the option to write to a file like to_image

from typing import List, Iterator, Tuple, Union, Optional, Any
from typing import List, Iterator, Tuple, Union, Optional, Any, cast

from collections import defaultdict
from collections import namedtuple
Expand Down Expand Up @@ -362,7 +362,8 @@ def _draw_multi_matches(

def lasso_highlight_image(
target_molecule: Union[str, dm.Mol],
search_molecules: Union[str, List[str], dm.Mol, List[dm.Mol]],
search_molecules: Union[str, List[str], dm.Mol, List[dm.Mol]] = None,
atom_indices: Optional[Union[List[int], List[List[int]]]] = None,
mol_size: Tuple[int, int] = (300, 300),
use_svg: Optional[bool] = True,
r_min: float = 0.3,
Expand All @@ -372,11 +373,13 @@ def lasso_highlight_image(
line_width: int = 2,
**kwargs: Any,
):
"""Create an image of a molecule with substructure matches using lasso-based highlighting.
"""Create an image of a molecule with substructure matches using lasso-based highlighting. Substructure matching is
optional and it's also possible to pass a list of list of atom indices to highlight.
args:
target_molecule: The molecule to be highlighted
search_molecules: The substructure to be identified
search_molecules: The substructure to be highlighted.
atom_indices: Atom indices to be highlighted substructure.
mol_size: The size of the image to be returned
use_svg: Whether to return an svg or png image
r_min: Radius of the smallest circle around atoms. Length is relative to average bond length (1 = avg. bond len).
Expand All @@ -389,15 +392,15 @@ def lasso_highlight_image(
https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html.
"""

if search_molecules is None:
search_molecules = []

## Step 0: Input validation

# check if the input is valid
if target_molecule is None or (isinstance(target_molecule, str) and len(target_molecule) == 0):
raise ValueError("Please enter a valid target molecule or smiles")

if search_molecules is None or (
isinstance(search_molecules, str) and len(search_molecules) == 0
):
raise ValueError("Please enter valid search molecules or smarts")

# less than 1 throws File parsing error: PNG header not recognized over 5,000 leads to a DecompressionBombError later on
if mol_size[0] < 1 or mol_size[0] > 5000 or mol_size[1] < 1 or mol_size[1] > 5000:
raise ValueError(
Expand All @@ -407,6 +410,43 @@ def lasso_highlight_image(
if isinstance(target_molecule, str):
target_molecule = dm.to_mol(target_molecule)

if target_molecule is None:
raise ValueError("Please enter a valid target molecule or smiles")

# Always make the type checker happy
target_mol = cast(dm.Mol, target_molecule)

## Step 1: Match the search molecules or SMARTS to the target molecule

# Make `search_molecules` a list if it is not already
if not isinstance(search_molecules, (list, tuple)):
search_molecules = [search_molecules]

atom_idx_list = []
for search_mol in search_molecules:
if isinstance(search_mol, str):
search_mol = dm.from_smarts(search_mol)

if search_mol is None or not isinstance(search_mol, dm.Mol):
raise ValueError(f"Please enter valid search molecules or smarts: {search_mol}")

matches = target_mol.GetSubstructMatches(search_mol)
if not matches:
logger.warning(f"No matching substructures found for {dm.to_smarts(search_mol)}")
else:
matched_atoms = set.union(*[set(x) for x in matches])
atom_idx_list.append(matched_atoms)

## Step 2: add the atom indices to the list if any
if atom_indices is not None:
if not isinstance(atom_indices[0], (list, tuple)):
atom_indices_list_of_list = [atom_indices]
else:
atom_indices_list_of_list = atom_indices
atom_idx_list += atom_indices_list_of_list

## Step 3: Prepare the molecule for drawing and draw it

mol = prepare_mol_for_drawing(target_molecule, kekulize=True)

if mol is None:
Expand All @@ -431,43 +471,7 @@ def lasso_highlight_image(
drawer.DrawMolecule(mol)
drawer.ClearDrawing()

# get the atom indices for the search molecules
atom_idx_list = []
if isinstance(search_molecules, str):
smart_obj = dm.to_mol(search_molecules)
matches = mol.GetSubstructMatches(smart_obj)
if not matches:
logger.warning(f"no matching substructure found for {search_molecules}")
else:
matched_atoms = set.union(*[set(x) for x in matches])
atom_idx_list.append(matched_atoms)

elif isinstance(search_molecules, dm.Mol):
matches = mol.GetSubstructMatches(search_molecules)
if not matches:
logger.warning(f"no matching substructure found for {dm.to_smiles(search_molecules)}")
else:
matched_atoms = set.union(*[set(x) for x in matches])
atom_idx_list.append(matched_atoms)

elif len(search_molecules) and isinstance(search_molecules[0], str):
for smart_str in search_molecules:
smart_obj = dm.to_mol(smart_str)
matches = mol.GetSubstructMatches(smart_obj)
if not matches:
logger.warning(f"no matching substructure found for {smart_str}")
else:
matched_atoms = set.union(*[set(x) for x in matches])
atom_idx_list.append(matched_atoms)

elif len(search_molecules) and isinstance(search_molecules[0], dm.Mol):
for smart_obj in search_molecules:
matches = mol.GetSubstructMatches(smart_obj)
if not matches:
logger.warning(f"no matching substructure found for {dm.to_smiles(smart_obj)}")
else:
matched_atoms = set.union(*[set(x) for x in matches])
atom_idx_list.append(matched_atoms)
## Step 4: Draw the matches

if color_list is None:
color_list = DEFAULT_LASSO_COLORS
Expand Down
52 changes: 37 additions & 15 deletions tests/test_viz_lasso_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,7 @@ def test_canvas_input_error():
with pytest.raises(ValueError):
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
smarts_list = ["CONN", "N#CC~CO", "C=CON", "CONNCN"]
dm.lasso_highlight_image(smi, smarts_list, (5001, 400))


def test_search_input_error_empty_str():
with pytest.raises(ValueError):
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
smarts_list = ""
dm.lasso_highlight_image(smi, smarts_list)
dm.lasso_highlight_image(smi, smarts_list, mol_size=(5001, 400))


def test_search_input_error_empty_list():
Expand All @@ -96,13 +89,6 @@ def test_search_input_error_empty_list():
assert dm.lasso_highlight_image(smi, smarts_list)


def test_search_input_error_None():
with pytest.raises(ValueError):
smi = "CO[C@@H](O)C1=C(O[C@H](F)Cl)C(C#N)=C1ONNC[NH3+]"
smarts_list = None
dm.lasso_highlight_image(smi, smarts_list)


def test_target_input_error_empty_str():
with pytest.raises(ValueError):
smi = ""
Expand Down Expand Up @@ -148,3 +134,39 @@ def test_PNG_is_returned():
from PIL import Image

assert isinstance(img, Image.Image)


def test_aromatic_query_work():
smi = "CC(N)Cc1c[nH]c2ccc3c(c12)CCCO3"
smarts_list = ["c1ccccc1"]
assert dm.lasso_highlight_image(smi, smarts_list)


def test_smarts_query():
smi = "CC(N)Cc1c[nH]c2ccc3c(c12)CCCO3"
smarts_list = "[#6]"
assert dm.lasso_highlight_image(smi, smarts_list)


def test_query_and_atom_indices_list():
dm.viz.lasso_highlight_image(
"CC(N)Cc1c[nH]c2ccc3c(c12)CCCO3",
search_molecules="c1ccccc1",
atom_indices=[[4, 5, 6], [1, 2, 3, 4]],
)


def test_atom_indices_list_of_list():
dm.viz.lasso_highlight_image(
"CC(N)Cc1c[nH]c2ccc3c(c12)CCCO3",
search_molecules=None,
atom_indices=[[4, 5, 6], [1, 2, 3, 4]],
)


def test_atom_indices_list():
dm.viz.lasso_highlight_image(
"CC(N)Cc1c[nH]c2ccc3c(c12)CCCO3",
search_molecules=None,
atom_indices=[4, 5, 6],
)

0 comments on commit f5a7584

Please sign in to comment.