From 6d3afd1a6a7b4249ab690dbdd43c860352ebecd5 Mon Sep 17 00:00:00 2001 From: Chris Brasnett <35073246+csbrasnett@users.noreply.github.com> Date: Mon, 22 Apr 2024 12:16:12 +0100 Subject: [PATCH 1/6] first pass at lazy merge --- bin/martinize2 | 12 +++++++++--- vermouth/processors/merge_chains.py | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/bin/martinize2 b/bin/martinize2 index f6e4f8d3b..a837216bc 100755 --- a/bin/martinize2 +++ b/bin/martinize2 @@ -319,7 +319,8 @@ def entry(): file_group.add_argument( "-merge", dest="merge_chains", - type=lambda x: x.split(","), + # type=lambda x: x.split(","), + type=str, action="append", help="Merge chains: e.g. -merge A,B,C (+)", ) @@ -968,8 +969,13 @@ def entry(): itp_paths = [] # Merge chains if required. if args.merge_chains: - for chain_set in args.merge_chains: - vermouth.MergeChains(chain_set).run_system(system) + if args.merge_chains[0] != 'all': + input_chain_sets = [i.split(",") for i in args.merge_chains] + for chain_set in input_chain_sets: + vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system) + else: + vermouth.MergeChains(chains=None, all_chains=True).run_system(system) + vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system) defines = () diff --git a/vermouth/processors/merge_chains.py b/vermouth/processors/merge_chains.py index 961f87b34..02b98c027 100644 --- a/vermouth/processors/merge_chains.py +++ b/vermouth/processors/merge_chains.py @@ -22,7 +22,7 @@ from ..processors.processor import Processor -def merge_chains(system, chains): +def merge_chains(system, chains, all_chains): """ Merge molecules with the given chains as a single molecule. @@ -42,8 +42,15 @@ def merge_chains(system, chains): The system to modify. chains: list[str] A container of chain identifier. 
+ all_chains: """ - chains = set(chains) + if not all_chains: + chains = set(chains) + else: + l = [] + for molecule in system.molecules: + l.append([node.get('chain') for node in molecule.nodes.values()][0]) + chains = set(l) merged = Molecule() merged._force_field = system.force_field has_merged = False @@ -65,8 +72,9 @@ def merge_chains(system, chains): class MergeChains(Processor): name = 'MergeChains' - def __init__(self, chains): + def __init__(self, chains, all_chains): self.chains = chains + self.all_chains = all_chains def run_system(self, system): - merge_chains(system, self.chains) + merge_chains(system, self.chains, self.all_chains) From 964050af776d90a688ecedd3f18160a39c6f9a19 Mon Sep 17 00:00:00 2001 From: csbrasnett Date: Tue, 23 Apr 2024 10:38:11 +0200 Subject: [PATCH 2/6] added some more info to the help --- bin/martinize2 | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/martinize2 b/bin/martinize2 index a837216bc..14caec2d4 100755 --- a/bin/martinize2 +++ b/bin/martinize2 @@ -319,10 +319,10 @@ def entry(): file_group.add_argument( "-merge", dest="merge_chains", - # type=lambda x: x.split(","), type=str, action="append", - help="Merge chains: e.g. -merge A,B,C (+)", + help="Merge chains: either a comma separated list of chains to merge e.g. -merge A,B,C (+) or all.\n" + "Can be given multiple times for different groups of chains to merge.", ) file_group.add_argument( "-resid", @@ -969,7 +969,7 @@ def entry(): itp_paths = [] # Merge chains if required. 
if args.merge_chains: - if args.merge_chains[0] != 'all': + if args.merge_chains[0] != "all": input_chain_sets = [i.split(",") for i in args.merge_chains] for chain_set in input_chain_sets: vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system) From 52836f4eb3139dcabc1044d9246755ef9c63f566 Mon Sep 17 00:00:00 2001 From: csbrasnett Date: Tue, 23 Apr 2024 13:09:09 +0200 Subject: [PATCH 3/6] addressed comments, added tests --- bin/martinize2 | 7 +- vermouth/processors/merge_chains.py | 30 ++++--- vermouth/tests/test_merge_chains.py | 121 ++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+), 12 deletions(-) create mode 100644 vermouth/tests/test_merge_chains.py diff --git a/bin/martinize2 b/bin/martinize2 index 14caec2d4..838f8f2b1 100755 --- a/bin/martinize2 +++ b/bin/martinize2 @@ -969,13 +969,14 @@ def entry(): itp_paths = [] # Merge chains if required. if args.merge_chains: - if args.merge_chains[0] != "all": + if "all" not in args.merge_chains: input_chain_sets = [i.split(",") for i in args.merge_chains] for chain_set in input_chain_sets: vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system) - else: + elif "all" in args.merge_chains and len(args.merge_chains) == 1: vermouth.MergeChains(chains=None, all_chains=True).run_system(system) - + else: + LOGGER.warning("Multiple conflicting merging arguments given. 
Please check input arguments.") vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system) defines = () diff --git a/vermouth/processors/merge_chains.py b/vermouth/processors/merge_chains.py index 02b98c027..b2b3148b3 100644 --- a/vermouth/processors/merge_chains.py +++ b/vermouth/processors/merge_chains.py @@ -20,6 +20,8 @@ from ..molecule import Molecule from ..processors.processor import Processor +from ..log_helpers import StyleAdapter, get_logger +LOGGER = StyleAdapter(get_logger(__name__)) def merge_chains(system, chains, all_chains): @@ -42,22 +44,32 @@ def merge_chains(system, chains, all_chains): The system to modify. chains: list[str] A container of chain identifier. - all_chains: + all_chains: bool + If True, all chains will be merged. """ - if not all_chains: - chains = set(chains) - else: - l = [] + if not all_chains and len(chains)>0: + _chains = set(chains) + elif all_chains and chains is None: + _chains = set() for molecule in system.molecules: - l.append([node.get('chain') for node in molecule.nodes.values()][0]) - chains = set(l) + # Molecules can contain multiple chains + _chains.update(node.get('chain') for node in molecule.nodes.values()) + else: + LOGGER.warning('Conflicting merging arguments have been given. 
Will abort merging.') + return + + try: + assert ''.join(list(_chains)).isalnum() + except TypeError: + LOGGER.warning('One or more of your chains does not have a chain identifier in input file.') + merged = Molecule() merged._force_field = system.force_field has_merged = False new_molecules = [] for molecule in system.molecules: molecule_chains = set(node.get('chain') for node in molecule.nodes.values()) - if molecule_chains.issubset(chains): + if molecule_chains.issubset(_chains): if not has_merged: merged.nrexcl = molecule.nrexcl new_molecules.append(merged) @@ -72,7 +84,7 @@ def merge_chains(system, chains, all_chains): class MergeChains(Processor): name = 'MergeChains' - def __init__(self, chains, all_chains): + def __init__(self, chains=None, all_chains=False): self.chains = chains self.all_chains = all_chains diff --git a/vermouth/tests/test_merge_chains.py b/vermouth/tests/test_merge_chains.py new file mode 100644 index 000000000..ce1441a40 --- /dev/null +++ b/vermouth/tests/test_merge_chains.py @@ -0,0 +1,121 @@ +# Copyright 2018 University of Groningen +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains unittests for vermouth.processors.merge_chains. 
+""" + +import networkx as nx +import pytest +from vermouth.system import System +from vermouth.molecule import Molecule +from vermouth.forcefield import ForceField +from vermouth.processors.merge_chains import ( + MergeChains +) +from vermouth.tests.datafiles import ( + FF_UNIVERSAL_TEST, +) + +@pytest.mark.parametrize('node_data, edge_data, merger, expected', [ + ( + [ + {'chain': 'A', 'resname': 'ALA', 'resid': 1}, + {'chain': 'A', 'resname': 'ALA', 'resid': 2}, + {'chain': 'A', 'resname': 'ALA', 'resid': 3}, + {'chain': 'B', 'resname': 'ALA', 'resid': 1}, + {'chain': 'B', 'resname': 'ALA', 'resid': 2}, + {'chain': 'B', 'resname': 'ALA', 'resid': 3} + ], + [(0, 1), (1, 2), (3, 4), (4, 5)], + {"chains": ["A", "B"], "all_chains": False}, + False + ), + ( + [ + {'chain': 'A', 'resname': 'ALA', 'resid': 1}, + {'chain': 'A', 'resname': 'ALA', 'resid': 2}, + {'chain': 'A', 'resname': 'ALA', 'resid': 3}, + {'chain': 'B', 'resname': 'ALA', 'resid': 1}, + {'chain': 'B', 'resname': 'ALA', 'resid': 2}, + {'chain': 'B', 'resname': 'ALA', 'resid': 3} + ], + [(0, 1), (1, 2), (3, 4), (4, 5)], + {"chains": None, "all_chains": True}, + False + ), + ( + [ + {'chain': 'A', 'resname': 'ALA', 'resid': 1}, + {'chain': 'A', 'resname': 'ALA', 'resid': 2}, + {'chain': 'A', 'resname': 'ALA', 'resid': 3}, + {'chain': 'B', 'resname': 'ALA', 'resid': 1}, + {'chain': 'B', 'resname': 'ALA', 'resid': 2}, + {'chain': 'B', 'resname': 'ALA', 'resid': 3} + ], + [(0, 1), (1, 2), (3, 4), (4, 5)], + {"chains": ["A", "B"], "all_chains": True}, + True + ), + ( + [ + {'chain': 'A', 'resname': 'ALA', 'resid': 1}, + {'chain': 'A', 'resname': 'ALA', 'resid': 2}, + {'chain': 'A', 'resname': 'ALA', 'resid': 3}, + {'chain': None, 'resname': 'ALA', 'resid': 1}, + {'chain': None, 'resname': 'ALA', 'resid': 2}, + {'chain': None, 'resname': 'ALA', 'resid': 3} + ], + [(0, 1), (1, 2), (3, 4), (4, 5)], + {"chains": ["A", "B"], "all_chains": True}, + True + ), + ( + [ + {'chain': 'A', 'resname': 'ALA', 'resid': 1}, + 
{'chain': 'A', 'resname': 'ALA', 'resid': 2}, + {'chain': 'A', 'resname': 'ALA', 'resid': 3}, + {'chain': None, 'resname': 'ALA', 'resid': 1}, + {'chain': None, 'resname': 'ALA', 'resid': 2}, + {'chain': None, 'resname': 'ALA', 'resid': 3} + ], + [(0, 1), (1, 2), (3, 4), (4, 5)], + {"chains": None, "all_chains": True}, + True + ), + +]) +def test_merge(caplog, node_data, edge_data, merger, expected): + """ + Tests that the merging works as expected. + """ + system = System(force_field=ForceField(FF_UNIVERSAL_TEST)) + mol = Molecule(force_field=system.force_field) + mol.add_nodes_from(enumerate(node_data)) + mol.add_edges_from(edge_data) + + mols = nx.connected_components(mol) + for nodes in mols: + system.add_molecule(mol.subgraph(nodes)) + + processor = MergeChains() + processor.chains = merger["chains"] + processor.all_chains = merger["all_chains"] + + caplog.clear() + processor.run_system(system) + + if expected: + assert any(rec.levelname == 'WARNING' for rec in caplog.records) + else: + assert caplog.records == [] \ No newline at end of file From 444719f8af427f6af77762417bf00e535c27b8e7 Mon Sep 17 00:00:00 2001 From: csbrasnett Date: Tue, 23 Apr 2024 17:17:47 +0200 Subject: [PATCH 4/6] addressed comments - made stronger warnings in martinize2 - clarified conditions for merging in merge_chains - improved warnings and exceptions in merge_chains - changed tests to reflect above --- bin/martinize2 | 16 ++++--- vermouth/processors/merge_chains.py | 13 +++--- vermouth/tests/test_merge_chains.py | 65 ++++++++++++++++------------- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/bin/martinize2 b/bin/martinize2 index f52542326..879d89f25 100755 --- a/bin/martinize2 +++ b/bin/martinize2 @@ -316,13 +316,14 @@ def entry(): default=False, help="Write separate topologies for identical chains", ) - file_group.add_argument( + chain_merging = file_group.add_argument( "-merge", dest="merge_chains", type=str, action="append", - help="Merge chains: either a comma
separated list of chains to merge e.g. -merge A,B,C (+) or all.\n" - "Can be given multiple times for different groups of chains to merge.", + help="Merge chains: either a comma separated list of chains to merge e.g. -merge A,B,C (+), or -merge all\n" + "if instead all chains in the input file should be merged.\n" + "Can be given multiple times for different groups of chains to merge.", ) file_group.add_argument( "-resid", @@ -973,14 +974,19 @@ def entry(): itp_paths = [] # Merge chains if required. if args.merge_chains: + #if "all" is not in the list of chains to be merged if "all" not in args.merge_chains: input_chain_sets = [i.split(",") for i in args.merge_chains] for chain_set in input_chain_sets: vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system) + #if "all" is in the list and is the only argument elif "all" in args.merge_chains and len(args.merge_chains) == 1: - vermouth.MergeChains(chains=None, all_chains=True).run_system(system) + vermouth.MergeChains(chains=[], all_chains=True).run_system(system) + #otherwise there are multiple arguments and we need to raise an ArgumentError else: - LOGGER.warning("Multiple conflicting merging arguments given. Please check input arguments.") + raise argparse.ArgumentError(chain_merging, + message=("Multiple conflicting merging arguments given. " + "Either specify -merge all or -merge A,B,C (+).")) vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system) defines = () diff --git a/vermouth/processors/merge_chains.py b/vermouth/processors/merge_chains.py index b2b3148b3..b4c9e2eb6 100644 --- a/vermouth/processors/merge_chains.py +++ b/vermouth/processors/merge_chains.py @@ -47,20 +47,17 @@ def merge_chains(system, chains, all_chains): all_chains: bool If True, all chains will be merged. 
""" - if not all_chains and len(chains)>0: + if not all_chains and len(chains) > 0: _chains = set(chains) - elif all_chains and chains is None: + elif all_chains and len(chains) == 0: _chains = set() for molecule in system.molecules: # Molecules can contain multiple chains _chains.update(node.get('chain') for node in molecule.nodes.values()) else: - LOGGER.warning('Conflicting merging arguments have been given. Will abort merging.') - return + raise ValueError - try: - assert ''.join(list(_chains)).isalnum() - except TypeError: + if any(not c for c in _chains): LOGGER.warning('One or more of your chains does not have a chain identifier in input file.') merged = Molecule() @@ -84,7 +81,7 @@ def merge_chains(system, chains, all_chains): class MergeChains(Processor): name = 'MergeChains' - def __init__(self, chains=None, all_chains=False): + def __init__(self, chains=[], all_chains=False): self.chains = chains self.all_chains = all_chains diff --git a/vermouth/tests/test_merge_chains.py b/vermouth/tests/test_merge_chains.py index ce1441a40..31cefc083 100644 --- a/vermouth/tests/test_merge_chains.py +++ b/vermouth/tests/test_merge_chains.py @@ -51,22 +51,9 @@ {'chain': 'B', 'resname': 'ALA', 'resid': 3} ], [(0, 1), (1, 2), (3, 4), (4, 5)], - {"chains": None, "all_chains": True}, + {"chains": [], "all_chains": True}, False ), - ( - [ - {'chain': 'A', 'resname': 'ALA', 'resid': 1}, - {'chain': 'A', 'resname': 'ALA', 'resid': 2}, - {'chain': 'A', 'resname': 'ALA', 'resid': 3}, - {'chain': 'B', 'resname': 'ALA', 'resid': 1}, - {'chain': 'B', 'resname': 'ALA', 'resid': 2}, - {'chain': 'B', 'resname': 'ALA', 'resid': 3} - ], - [(0, 1), (1, 2), (3, 4), (4, 5)], - {"chains": ["A", "B"], "all_chains": True}, - True - ), ( [ {'chain': 'A', 'resname': 'ALA', 'resid': 1}, @@ -77,20 +64,7 @@ {'chain': None, 'resname': 'ALA', 'resid': 3} ], [(0, 1), (1, 2), (3, 4), (4, 5)], - {"chains": ["A", "B"], "all_chains": True}, - True - ), - ( - [ - {'chain': 'A', 'resname': 'ALA', 'resid': 
1}, - {'chain': 'A', 'resname': 'ALA', 'resid': 2}, - {'chain': 'A', 'resname': 'ALA', 'resid': 3}, - {'chain': None, 'resname': 'ALA', 'resid': 1}, - {'chain': None, 'resname': 'ALA', 'resid': 2}, - {'chain': None, 'resname': 'ALA', 'resid': 3} - ], - [(0, 1), (1, 2), (3, 4), (4, 5)], - {"chains": None, "all_chains": True}, + {"chains": [], "all_chains": True}, True ), @@ -118,4 +92,37 @@ def test_merge(caplog, node_data, edge_data, merger, expected): if expected: assert any(rec.levelname == 'WARNING' for rec in caplog.records) else: - assert caplog.records == [] \ No newline at end of file + assert caplog.records == [] + +def test_too_many_args(): + """ + Tests that error is raised when too many arguments are given. + """ + node_data = [ + {'chain': 'A', 'resname': 'ALA', 'resid': 1}, + {'chain': 'A', 'resname': 'ALA', 'resid': 2}, + {'chain': 'A', 'resname': 'ALA', 'resid': 3}, + {'chain': 'B', 'resname': 'ALA', 'resid': 1}, + {'chain': 'B', 'resname': 'ALA', 'resid': 2}, + {'chain': 'B', 'resname': 'ALA', 'resid': 3} + ] + edge_data = [(0, 1), (1, 2), (3, 4), (4, 5)] + + system = System(force_field=ForceField(FF_UNIVERSAL_TEST)) + mol = Molecule(force_field=system.force_field) + mol.add_nodes_from(enumerate(node_data)) + mol.add_edges_from(edge_data) + + mols = nx.connected_components(mol) + for nodes in mols: + system.add_molecule(mol.subgraph(nodes)) + + merger = {"chains": ["A", "B"], "all_chains": True} + + processor = MergeChains() + processor.chains = merger["chains"] + processor.all_chains = merger["all_chains"] + + with pytest.raises(ValueError): + processor.run_system(system) + From e0a9c0a0c55dc91888a4c44c156e7b3f97c6fdb8 Mon Sep 17 00:00:00 2001 From: csbrasnett Date: Wed, 24 Apr 2024 17:26:53 +0200 Subject: [PATCH 5/6] changed suggestions --- vermouth/processors/merge_chains.py | 6 +++--- vermouth/tests/test_merge_chains.py | 5 +---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vermouth/processors/merge_chains.py 
b/vermouth/processors/merge_chains.py index b4c9e2eb6..4ce5ab570 100644 --- a/vermouth/processors/merge_chains.py +++ b/vermouth/processors/merge_chains.py @@ -55,7 +55,7 @@ def merge_chains(system, chains, all_chains): # Molecules can contain multiple chains _chains.update(node.get('chain') for node in molecule.nodes.values()) else: - raise ValueError + raise ValueError("Can specify specific chains or all chains, but not both") if any(not c for c in _chains): LOGGER.warning('One or more of your chains does not have a chain identifier in input file.') @@ -81,8 +81,8 @@ def merge_chains(system, chains, all_chains): class MergeChains(Processor): name = 'MergeChains' - def __init__(self, chains=[], all_chains=False): - self.chains = chains + def __init__(self, chains=None, all_chains=False): + self.chains = chains or [] self.all_chains = all_chains def run_system(self, system): diff --git a/vermouth/tests/test_merge_chains.py b/vermouth/tests/test_merge_chains.py index 31cefc083..2a260eff6 100644 --- a/vermouth/tests/test_merge_chains.py +++ b/vermouth/tests/test_merge_chains.py @@ -82,10 +82,7 @@ def test_merge(caplog, node_data, edge_data, merger, expected): for nodes in mols: system.add_molecule(mol.subgraph(nodes)) - processor = MergeChains() - processor.chains = merger["chains"] - processor.all_chains = merger["all_chains"] - + processor = MergeChains(**merger) caplog.clear() processor.run_system(system) From d5ee2ddfdd18a95637fdfee31b0ac042b1cc6085 Mon Sep 17 00:00:00 2001 From: csbrasnett Date: Wed, 24 Apr 2024 17:28:20 +0200 Subject: [PATCH 6/6] missed one --- vermouth/tests/test_merge_chains.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vermouth/tests/test_merge_chains.py b/vermouth/tests/test_merge_chains.py index 2a260eff6..c2db94f5a 100644 --- a/vermouth/tests/test_merge_chains.py +++ b/vermouth/tests/test_merge_chains.py @@ -116,9 +116,7 @@ def test_too_many_args(): merger = {"chains": ["A", "B"], "all_chains": True} - processor 
= MergeChains() - processor.chains = merger["chains"] - processor.all_chains = merger["all_chains"] + processor = MergeChains(**merger) with pytest.raises(ValueError): processor.run_system(system)