Skip to content

Commit

Permalink
Save the switching trajectory into files to avoid OOM issue (#91)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update

* fix test

---------

Co-authored-by: William (Zhiyi) Wu <[email protected]>
Co-authored-by: Marcus Wieder <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2023
1 parent 672ddbd commit 4325d3b
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 123 deletions.
45 changes: 34 additions & 11 deletions endstate_correction/neq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Provide the functions for non-equilibrium switching."""


import os
import pickle
import random
from typing import Tuple
Expand All @@ -22,7 +21,8 @@ def perform_switching(
nr_of_switches: int = 50,
save_trajs: bool = False,
save_endstates: bool = False,
) -> Tuple[list, list, list]:
workdir: str = ".",
) -> Tuple[list, list]:
"""Perform NEQ switching using the provided lambda schema on the passed simulation instance.
Args:
Expand All @@ -37,9 +37,9 @@ def perform_switching(
RuntimeError: if the number of lambda states is less than 2
Returns:
Tuple[list, list, list]: work values, endstate samples, switching trajectories
Tuple[list, list]: work values, endstate samples
"""

os.makedirs(workdir, exist_ok=True)
if save_endstates:
print("Endstate of each switch will be saved.")
if save_trajs:
Expand All @@ -62,7 +62,7 @@ def perform_switching(
print("NEQ switching: dW will be calculated")

# start with switch
for _ in tqdm(range(nr_of_switches)):
for switch_index in tqdm(range(nr_of_switches)):
if save_trajs:
# if switching trajectories need to be saved, create an empty list at the beginning
# of each switch for saving conformations
Expand Down Expand Up @@ -91,7 +91,9 @@ def perform_switching(
sim.context.setParameter("lambda_interpolate", lambdas[idx_lamb])
if save_trajs:
# save conformation at the beginning of each switch
switching_trajectory.append(get_positions(sim))
switching_trajectory.append(
get_positions(sim).value_in_unit(unit.nanometer)
)
# test if neq or instantaneous swithching: if neq, perform integration step
if not inst_switching:
# perform 1 simulation step
Expand All @@ -108,9 +110,31 @@ def perform_switching(
# TODO: expand to reduced potential
if save_trajs:
# at the end of each switch save the last conformation
switching_trajectory.append(get_positions(sim))
# collect all switching trajectories as a list of lists
all_switching_trajectories.append(switching_trajectory)
switching_trajectory.append(
get_positions(sim).value_in_unit(unit.nanometer)
)

topology = samples.topology
unitcell_lengths = samples[0].unitcell_lengths
unitcell_angles = samples[0].unitcell_lengths
switching_trajectory_length = len(switching_trajectory)
if unitcell_lengths is None:
switching_trajectory = Trajectory(
topology=topology,
xyz=np.stack(switching_trajectory),
)
else:
switching_trajectory = Trajectory(
topology=topology,
xyz=np.stack(switching_trajectory),
unitcell_lengths=np.ones((switching_trajectory_length, 3))
* unitcell_lengths,
unitcell_angles=np.ones((switching_trajectory_length, 3))
* unitcell_angles,
)
switching_trajectory.save(
f"{workdir}/switching_trajectory_{switch_index}.dcd"
)
if save_endstates:
# save the endstate conformation
endstate_samples.append(get_positions(sim))
Expand All @@ -119,7 +143,6 @@ def perform_switching(
return (
np.array(ws) * unit.kilojoule_per_mole,
endstate_samples,
all_switching_trajectories,
)


Expand Down
89 changes: 57 additions & 32 deletions endstate_correction/protocol.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Provide functions for the endstate correction workflow."""


import os
from dataclasses import dataclass, field
from typing import List, Optional, Union

import mdtraj as md
import numpy as np
import pandas as pd
from openmm import unit
from openmm.app import Simulation
from pymbar import MBAR
Expand Down Expand Up @@ -145,10 +143,10 @@ class SMCProtocol(BaseProtocol):
"""SMC-specific protocol"""

nr_of_walkers: int = -1 # number of walkers for SMC
protocol_length: int = 1_000 # length of the SMC protocol
nr_of_resampling_steps: int = 1_000 # number of times walkers are resampled
protocol_length: int = 1_000 # length of the SMC protocol
nr_of_resampling_steps: int = 1_000 # number of times walkers are resampled
save_endstates: bool = False

def __setattr__(self, prop, val):
if prop == "protocol_length":
self._check_protocol_length(val)
Expand All @@ -158,20 +156,22 @@ def __setattr__(self, prop, val):

@staticmethod
def _check_protocol_length(protocol_length):
if protocol_length%10 != 0:
if protocol_length % 10 != 0:
raise ValueError("Protocol length has to be a multiple factor of 10!")

@staticmethod
def _check_nr_of_resampling_steps(nr_of_resampling_steps):
if nr_of_resampling_steps%10 != 0:
raise ValueError("Number of resampling steps has to be a multiple factor of 10!")
if nr_of_resampling_steps % 10 != 0:
raise ValueError(
"Number of resampling steps has to be a multiple factor of 10!"
)

def __post_init__(self):
super().__post_init__() # Call base class's post-init


@dataclass
class AllProtocol():
class AllProtocol:
"""Dataclass for running all protocols"""

fep_protocol: Union[None, FEPProtocol] = None
Expand All @@ -180,10 +180,11 @@ class AllProtocol():

# check if reference or target samples are provided
def __post_init__(self):
self.fep_protocol.__post_init__()
self.fep_protocol.__post_init__()
self.neq_protocol.__post_init__()
self.smc_protocol.__post_init__()


class BaseResults:
"""Base class for all protocol results"""

Expand All @@ -199,31 +200,52 @@ class EquResults(BaseResults):
class FEPResults(BaseResults):
"""FEP-specific results"""

dE_reference_to_target: np.array = field(default_factory=lambda: np.array([])) # dE from reference to target
dE_target_to_reference: np.array = field(default_factory=lambda: np.array([])) # dE from target to reference
dE_reference_to_target: np.array = field(
default_factory=lambda: np.array([])
) # dE from reference to target
dE_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # dE from target to reference


@dataclass
class NEQResults(BaseResults):
"""Provides a dataclass containing the results of a protocol"""

W_reference_to_target: np.array = field(default_factory=lambda: np.array([])) # W from reference to target
W_target_to_reference: np.array = field(default_factory=lambda: np.array([])) # W from target to reference
endstate_samples_reference_to_target: np.array = field(default_factory=lambda: np.array([])) # endstate samples from reference to target
endstate_samples_target_to_reference: np.array = field(default_factory=lambda: np.array([])) # endstate samples from target to reference
switching_traj_reference_to_target: np.array = field(default_factory=lambda: np.array([])) # switching traj from reference to target
switching_traj_target_to_reference: np.array = field(default_factory=lambda: np.array([])) # switching traj from target to reference
W_reference_to_target: np.array = field(
default_factory=lambda: np.array([])
) # W from reference to target
W_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # W from target to reference
endstate_samples_reference_to_target: np.array = field(
default_factory=lambda: np.array([])
) # endstate samples from reference to target
endstate_samples_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # endstate samples from target to reference
switching_traj_reference_to_target: np.array = field(
default_factory=lambda: np.array([])
) # switching traj from reference to target
switching_traj_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # switching traj from target to reference


@dataclass
class SMCResults(BaseResults):
logZ: float = 0.0 # free energy difference
effective_sample_size: list = field(default_factory=list) # effective sample size
endstate_samples_reference_to_target: np.array = field(default_factory=lambda: np.array([])) # endstate samples from reference to target
endstate_samples_target_to_reference: np.array = field(default_factory=lambda: np.array([])) # endstate samples from target to reference
endstate_samples_reference_to_target: np.array = field(
default_factory=lambda: np.array([])
) # endstate samples from reference to target
endstate_samples_target_to_reference: np.array = field(
default_factory=lambda: np.array([])
) # endstate samples from target to reference


@dataclass
class AllResults():
class AllResults:
"""Dataclass for combined results of all protocols"""

equ_results: Union[None, EquResults] = None
Expand All @@ -232,19 +254,24 @@ class AllResults():
smc_results: Union[None, SMCResults] = None


def perform_endstate_correction(protocol: Union[BaseProtocol, AllProtocol]) -> AllResults:
def perform_endstate_correction(
protocol: Union[BaseProtocol, AllProtocol], workdir: str = "."
) -> AllResults:
"""Perform endstate correction using the provided protocol.
Args:
protocol (Union[BaseProtocol, AllProtocol]): defines the endstate correction.
protocol (Union[BaseProtocol, AllProtocol]): defines the endstate correction.
Either a specific protocol or a collection of protocols.
workdir: The working directory to save the output.
Returns:
BaseResults: results generated using the passed protocol
"""
from endstate_correction.constant import kBT
from endstate_correction.neq import perform_switching

os.makedirs(workdir, exist_ok=True)

r = AllResults()
if isinstance(protocol, AllProtocol) or isinstance(protocol, FEPProtocol):
print(f"Performing endstate correction using FEP")
Expand All @@ -262,7 +289,7 @@ def perform_endstate_correction(protocol: Union[BaseProtocol, AllProtocol]) -> A
# from reference to target potential
if protocol_.reference_samples is not None: # if reference samples are provided
print("Performing FEP from reference to target potential")
dEs, _, _ = perform_switching(
dEs, _ = perform_switching(
sim,
lambdas=list_of_lambda_values,
samples=protocol_.reference_samples,
Expand All @@ -274,7 +301,7 @@ def perform_endstate_correction(protocol: Union[BaseProtocol, AllProtocol]) -> A
# from target to reference potential
if protocol_.target_samples is not None: # if target samples are provided
print("Performing FEP from target to reference potential")
dEs, _, _ = perform_switching(
dEs, _ = perform_switching(
sim,
lambdas=np.flip(
list_of_lambda_values
Expand Down Expand Up @@ -303,7 +330,7 @@ def perform_endstate_correction(protocol: Union[BaseProtocol, AllProtocol]) -> A
smc_sampler.perform_SMC(
nr_of_walkers=protocol_.nr_of_walkers,
protocol_length=protocol_.protocol_length,
nr_of_resampling_steps=protocol_.nr_of_resampling_steps
nr_of_resampling_steps=protocol_.nr_of_resampling_steps,
)

r_smc.logZ = smc_sampler.logZ
Expand Down Expand Up @@ -332,27 +359,25 @@ def perform_endstate_correction(protocol: Union[BaseProtocol, AllProtocol]) -> A
(
Ws,
endstates_reference_to_target,
trajs_reference_to_target,
) = perform_switching(
sim,
lambdas=list_of_lambda_values,
samples=protocol_.reference_samples,
nr_of_switches=protocol_.nr_of_switches,
save_endstates=protocol_.save_endstates,
save_trajs=protocol_.save_trajs,
workdir=f"{workdir}/reference_to_target",
)
Ws_reference_to_target = np.array(Ws / kBT) # remove units
r_neq.W_reference_to_target = Ws_reference_to_target
r_neq.endstate_samples_reference_to_target = endstates_reference_to_target
r_neq.switching_traj_reference_to_target = trajs_reference_to_target

# from target to reference potential
if protocol_.target_samples is not None:
print("Performing NEQ from target to reference potential")
(
Ws,
endstates_target_to_reference,
trajs_target_to_reference,
) = perform_switching(
sim,
lambdas=np.flip(
Expand All @@ -362,11 +387,11 @@ def perform_endstate_correction(protocol: Union[BaseProtocol, AllProtocol]) -> A
nr_of_switches=protocol_.nr_of_switches,
save_endstates=protocol_.save_endstates,
save_trajs=protocol_.save_trajs,
workdir=f"{workdir}/target_to_reference",
)
Ws_target_to_reference = np.array(Ws / kBT)
r_neq.W_target_to_reference = Ws_target_to_reference
r_neq.endstate_samples_target_to_reference = endstates_target_to_reference
r_neq.switching_traj_target_to_reference = trajs_target_to_reference
r.neq_results = r_neq

return r
Loading

0 comments on commit 4325d3b

Please sign in to comment.