Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy merge #589

Merged
merged 8 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions bin/martinize2
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ def entry():
default=False,
help="Write separate topologies for identical chains",
)
file_group.add_argument(
chain_merging = file_group.add_argument(
"-merge",
dest="merge_chains",
type=lambda x: x.split(","),
type=str,
action="append",
help="Merge chains: e.g. -merge A,B,C (+)",
help="Merge chains: either a comma separated list of chains to merge e.g. -merge A,B,C (+), or -merge all\n"
"if instead all chains in the input file should be merged.\n"
"Can be given multiple times for different groups of chains to merge.",
)
file_group.add_argument(
"-resid",
Expand Down Expand Up @@ -974,8 +976,19 @@ def entry():
itp_paths = []
# Merge chains if required.
if args.merge_chains:
for chain_set in args.merge_chains:
vermouth.MergeChains(chain_set).run_system(system)
#if "all" is not in the list of chains to be merged
if "all" not in args.merge_chains:
input_chain_sets = [i.split(",") for i in args.merge_chains]
for chain_set in input_chain_sets:
vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system)
#if "all" is in the list and is the only argument
elif "all" in args.merge_chains and len(args.merge_chains) == 1:
vermouth.MergeChains(chains=[], all_chains=True).run_system(system)
#otherwise there are multiple arguments and we need to raise an ArgumentError
else:
raise argparse.ArgumentError(chain_merging,
message=("Multiple conflicting merging arguments given. "
"Either specify -merge all or -merge A,B,C (+)."))
vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system)
defines = ()

Expand Down
29 changes: 23 additions & 6 deletions vermouth/processors/merge_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

from ..molecule import Molecule
from ..processors.processor import Processor
from ..log_helpers import StyleAdapter, get_logger
LOGGER = StyleAdapter(get_logger(__name__))


def merge_chains(system, chains):
def merge_chains(system, chains, all_chains):
"""
Merge molecules with the given chains as a single molecule.

Expand All @@ -42,15 +44,29 @@ def merge_chains(system, chains):
The system to modify.
chains: list[str]
A container of chain identifier.
all_chains: bool
If True, all chains will be merged.
"""
chains = set(chains)
if not all_chains and len(chains) > 0:
_chains = set(chains)
elif all_chains and len(chains) == 0:
_chains = set()
for molecule in system.molecules:
# Molecules can contain multiple chains
_chains.update(node.get('chain') for node in molecule.nodes.values())
else:
raise ValueError("Can specify specific chains or all chains, but not both")

if any(not c for c in _chains):
LOGGER.warning('One or more of your chains does not have a chain identifier in input file.')

merged = Molecule()
merged._force_field = system.force_field
has_merged = False
new_molecules = []
for molecule in system.molecules:
molecule_chains = set(node.get('chain') for node in molecule.nodes.values())
if molecule_chains.issubset(chains):
if molecule_chains.issubset(_chains):
if not has_merged:
merged.nrexcl = molecule.nrexcl
new_molecules.append(merged)
Expand All @@ -65,8 +81,9 @@ def merge_chains(system, chains):
class MergeChains(Processor):
    """
    Processor that merges molecules of a system into a single molecule.

    The selection given at construction time is stored on the instance and
    forwarded unchanged to :func:`merge_chains` when the processor is run.
    :func:`merge_chains` raises a :exc:`ValueError` unless exactly one of
    the two selections is active: a non-empty `chains`, or `all_chains`
    set to True with an empty `chains`.

    Parameters
    ----------
    chains: list[str], optional
        Chain identifiers to merge. ``None`` (or any other falsy value) is
        stored as an empty list.
    all_chains: bool, optional
        If True, all chains will be merged.
    """
    name = 'MergeChains'

    def __init__(self, chains=None, all_chains=False):
        # Normalise a missing selection to an empty list so that
        # `merge_chains` can always take its length.
        self.chains = chains or []
        self.all_chains = all_chains

    def run_system(self, system):
        """
        Merge the selected chains of `system` through :func:`merge_chains`.

        Parameters
        ----------
        system
            The system to modify.
        """
        merge_chains(system, self.chains, self.all_chains)
123 changes: 123 additions & 0 deletions vermouth/tests/test_merge_chains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2018 University of Groningen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains unittests for vermouth.processors.merge_chains.
"""

import networkx as nx
import pytest
from vermouth.system import System
from vermouth.molecule import Molecule
from vermouth.forcefield import ForceField
from vermouth.processors.merge_chains import (
MergeChains
)
from vermouth.tests.datafiles import (
FF_UNIVERSAL_TEST,
)

def _ala_residues(chain):
    """
    Build the node attributes of three ALA residues (resid 1 to 3) that all
    carry the chain identifier `chain`.
    """
    return [
        {'chain': chain, 'resname': 'ALA', 'resid': resid}
        for resid in (1, 2, 3)
    ]


@pytest.mark.parametrize('node_data, edge_data, merger, expected', [
    # Two named chains, merged by listing them explicitly: no warning.
    (
        _ala_residues('A') + _ala_residues('B'),
        [(0, 1), (1, 2), (3, 4), (4, 5)],
        {"chains": ["A", "B"], "all_chains": False},
        False
    ),
    # Two named chains, merged through `all_chains`: no warning.
    (
        _ala_residues('A') + _ala_residues('B'),
        [(0, 1), (1, 2), (3, 4), (4, 5)],
        {"chains": [], "all_chains": True},
        False
    ),
    # One chain has no identifier and everything is merged: warning.
    (
        _ala_residues('A') + _ala_residues(None),
        [(0, 1), (1, 2), (3, 4), (4, 5)],
        {"chains": [], "all_chains": True},
        True
    ),
])
def test_merge(caplog, node_data, edge_data, merger, expected):
    """
    Run MergeChains on a small system and check the log output.

    `expected` tells whether at least one WARNING record must be logged;
    when it is False, no record at all may be logged.
    """
    system = System(force_field=ForceField(FF_UNIVERSAL_TEST))
    template = Molecule(force_field=system.force_field)
    template.add_nodes_from(enumerate(node_data))
    template.add_edges_from(edge_data)

    # Every connected component of the template becomes its own molecule.
    for component in nx.connected_components(template):
        system.add_molecule(template.subgraph(component))

    processor = MergeChains(**merger)
    # Drop anything logged during the setup so only the run is inspected.
    caplog.clear()
    processor.run_system(system)

    if not expected:
        assert caplog.records == []
    else:
        assert any(rec.levelname == 'WARNING' for rec in caplog.records)

def test_too_many_args():
    """
    Check that a ValueError is raised when both an explicit list of chains
    and `all_chains=True` are requested at the same time.
    """
    # Three ALA residues on chain A followed by three on chain B.
    node_data = [
        {'chain': chain, 'resname': 'ALA', 'resid': resid}
        for chain in ('A', 'B')
        for resid in (1, 2, 3)
    ]
    edge_data = [(0, 1), (1, 2), (3, 4), (4, 5)]

    system = System(force_field=ForceField(FF_UNIVERSAL_TEST))
    template = Molecule(force_field=system.force_field)
    template.add_nodes_from(enumerate(node_data))
    template.add_edges_from(edge_data)

    # Every connected component of the template becomes its own molecule.
    for component in nx.connected_components(template):
        system.add_molecule(template.subgraph(component))

    # Building the processor is outside the `raises` block: only running
    # it is expected to fail.
    processor = MergeChains(chains=["A", "B"], all_chains=True)

    with pytest.raises(ValueError):
        processor.run_system(system)

Loading