Skip to content

Commit

Permalink
addressed comments, added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
csbrasnett committed Apr 23, 2024
1 parent 964050a commit 52836f4
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 12 deletions.
7 changes: 4 additions & 3 deletions bin/martinize2
Original file line number Diff line number Diff line change
Expand Up @@ -969,13 +969,14 @@ def entry():
itp_paths = []
# Merge chains if required.
if args.merge_chains:
if args.merge_chains[0] != "all":
if "all" not in args.merge_chains:
input_chain_sets = [i.split(",") for i in args.merge_chains]
for chain_set in input_chain_sets:
vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system)
else:
elif "all" in args.merge_chains and len(args.merge_chains) == 1:
vermouth.MergeChains(chains=None, all_chains=True).run_system(system)

else:
LOGGER.warning("Multiple conflicting merging arguments given. Please check input arguments.")
vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system)
defines = ()

Expand Down
30 changes: 21 additions & 9 deletions vermouth/processors/merge_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from ..molecule import Molecule
from ..processors.processor import Processor
from ..log_helpers import StyleAdapter, get_logger
LOGGER = StyleAdapter(get_logger(__name__))


def merge_chains(system, chains, all_chains):
Expand All @@ -42,22 +44,32 @@ def merge_chains(system, chains, all_chains):
The system to modify.
chains: list[str]
A container of chain identifier.
all_chains:
all_chains: bool
If True, all chains will be merged.
"""
if not all_chains:
chains = set(chains)
else:
l = []
if not all_chains and len(chains)>0:
_chains = set(chains)
elif all_chains and chains is None:
_chains = set()
for molecule in system.molecules:
l.append([node.get('chain') for node in molecule.nodes.values()][0])
chains = set(l)
# Molecules can contain multiple chains
_chains.update(node.get('chain') for node in molecule.nodes.values())
else:
LOGGER.warning('Conflicting merging arguments have been given. Will abort merging.')
return

try:
assert ''.join(list(_chains)).isalnum()
except TypeError:
LOGGER.warning('One or more of your chains does not have a chain identifier in input file.')

merged = Molecule()
merged._force_field = system.force_field
has_merged = False
new_molecules = []
for molecule in system.molecules:
molecule_chains = set(node.get('chain') for node in molecule.nodes.values())
if molecule_chains.issubset(chains):
if molecule_chains.issubset(_chains):
if not has_merged:
merged.nrexcl = molecule.nrexcl
new_molecules.append(merged)
Expand All @@ -72,7 +84,7 @@ def merge_chains(system, chains, all_chains):
class MergeChains(Processor):
name = 'MergeChains'

def __init__(self, chains, all_chains):
def __init__(self, chains=None, all_chains=False):
self.chains = chains
self.all_chains = all_chains

Expand Down
121 changes: 121 additions & 0 deletions vermouth/tests/test_merge_chains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2018 University of Groningen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains unittests for vermouth.processors.merge_chains.
"""

import networkx as nx
import pytest
from vermouth.system import System
from vermouth.molecule import Molecule
from vermouth.forcefield import ForceField
from vermouth.processors.merge_chains import (
MergeChains
)
from vermouth.tests.datafiles import (
FF_UNIVERSAL_TEST,
)

@pytest.mark.parametrize('node_data, edge_data, merger, expected', [
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': 'B', 'resname': 'ALA', 'resid': 1},
{'chain': 'B', 'resname': 'ALA', 'resid': 2},
{'chain': 'B', 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": ["A", "B"], "all_chains": False},
False
),
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': 'B', 'resname': 'ALA', 'resid': 1},
{'chain': 'B', 'resname': 'ALA', 'resid': 2},
{'chain': 'B', 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": None, "all_chains": True},
False
),
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': 'B', 'resname': 'ALA', 'resid': 1},
{'chain': 'B', 'resname': 'ALA', 'resid': 2},
{'chain': 'B', 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": ["A", "B"], "all_chains": True},
True
),
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': None, 'resname': 'ALA', 'resid': 1},
{'chain': None, 'resname': 'ALA', 'resid': 2},
{'chain': None, 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": ["A", "B"], "all_chains": True},
True
),
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': None, 'resname': 'ALA', 'resid': 1},
{'chain': None, 'resname': 'ALA', 'resid': 2},
{'chain': None, 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": None, "all_chains": True},
True
),
])
def test_merge(caplog, node_data, edge_data, merger, expected):
"""
Tests that the merging works as expected.
"""
system = System(force_field=ForceField(FF_UNIVERSAL_TEST))
mol = Molecule(force_field=system.force_field)
mol.add_nodes_from(enumerate(node_data))
mol.add_edges_from(edge_data)

mols = nx.connected_components(mol)
for nodes in mols:
system.add_molecule(mol.subgraph(nodes))

processor = MergeChains()
processor.chains = merger["chains"]
processor.all_chains = merger["all_chains"]

caplog.clear()
processor.run_system(system)

if expected:
assert any(rec.levelname == 'WARNING' for rec in caplog.records)
else:
assert caplog.records == []

0 comments on commit 52836f4

Please sign in to comment.