Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the n_protocol_dag_results property to ProtocolResult #381

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions gufe/protocols/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import abc
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Sized

from openff.units import Quantity

Expand All @@ -19,7 +19,6 @@
from .protocoldag import ProtocolDAG, ProtocolDAGResult
from .protocolunit import ProtocolUnit


class ProtocolResult(GufeTokenizable):
"""
Container for all results for a single :class:`Transformation`.
Expand All @@ -33,19 +32,33 @@ class ProtocolResult(GufeTokenizable):
- `get_uncertainty`
"""

def __init__(self, **data):
def __init__(self, n_protocol_dag_results: int = 0, **data):
self._data = data

if not n_protocol_dag_results >= 0:
raise ValueError("`n_protocol_dag_results` must be an integer greater than or equal to zero")

self._n_protocol_dag_results = n_protocol_dag_results

@classmethod
def _defaults(cls):
return {}

def _to_dict(self):
return {"data": self.data}
return {'n_protocol_dag_results': self.n_protocol_dag_results, 'data': self.data}

@classmethod
def _from_dict(cls, dct: dict):
return cls(**dct["data"])
# TODO: remove in gufe 2.0
try:
n_protocol_dag_results = dct['n_protocol_dag_results']
except KeyError:
n_protocol_dag_results = 0
return cls(n_protocol_dag_results=n_protocol_dag_results, **dct['data'])

@property
def n_protocol_dag_results(self) -> int:
return self._n_protocol_dag_results

@property
def data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -260,7 +273,14 @@ def gather(
ProtocolResult
Aggregated results from many `ProtocolDAGResult`s from a given `Protocol`.
"""
return self.result_cls(**self._gather(protocol_dag_results))
# Iterable does not implement __len__ and makes no guarantees that
# protocol_dag_results is finite, checking both in method signature
# doesn't appear possible, explicitly check for __len__ through the
# Sized type
if not isinstance(protocol_dag_results, Sized):
raise ValueError("`protocol_dag_results` must implement `__len__`")
return self.result_cls(n_protocol_dag_results=len(protocol_dag_results),
**self._gather(protocol_dag_results))

@abc.abstractmethod
def _gather(
Expand Down
34 changes: 30 additions & 4 deletions gufe/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# For details, see https://github.com/OpenFreeEnergy/gufe
import datetime
import itertools

from typing import Optional, Iterable, List, Dict, Any, Union, Sized


import pathlib
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Union


import networkx as nx
import numpy as np
Expand Down Expand Up @@ -366,15 +370,37 @@ def test_create_execute_gather(self, protocol_dag):
# gather aggregated results of interest
protocolresult = protocol.gather([dagresult])

assert len(protocolresult.data["logs"]) == 1
assert len(protocolresult.data["logs"][0]) == 21 + 1
assert protocolresult.n_protocol_dag_results == 1
assert len(protocolresult.data['logs']) == 1
assert len(protocolresult.data['logs'][0]) == 21 + 1

assert protocolresult.get_estimate() == 95500.0

def test_gather_infinite_iterable_guardrail(self, protocol_dag):
protocol, dag, dagresult = protocol_dag

assert dagresult.ok()

# we want an infinite generator, but one that would actually stop early in case
# the guardrail doesn't work, but the type system doesn't know that
def infinite_generator():
while True:
yield dag
break

gen = infinite_generator()
assert isinstance(gen, Iterable)
assert not isinstance(gen, Sized)

with pytest.raises(ValueError, match="`protocol_dag_results` must implement `__len__`"):
protocol.gather(infinite_generator())

def test_deprecation_warning_on_dict_mapping(
self, instance, vacuum_ligand, solvated_ligand
):
lig = solvated_ligand.components["ligand"]
lig = solvated_ligand.components['ligand']


mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})

with pytest.warns(
Expand Down
30 changes: 30 additions & 0 deletions gufe/tests/test_protocolresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,33 @@ def instance(self):
estimate=4.2 * unit.kilojoule_per_mole,
uncertainty=0.2 * unit.kilojoule_per_mole,
)

def test_protocolresult_get_estimate(self, instance):
assert instance.get_estimate() == 4.2 * unit.kilojoule_per_mole

def test_protocolresult_get_uncertainty(self, instance):
assert instance.get_uncertainty() == 0.2 * unit.kilojoule_per_mole

def test_protocolresult_default_n_protocol_dag_results(self, instance):
assert instance.n_protocol_dag_results == 0

def test_protocol_result_from_dict_missing_n_protocol_dag_results(self, instance):
protocol_result_dict_form = instance.to_dict()
assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance
del protocol_result_dict_form['n_protocol_dag_results']
assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance

@pytest.mark.parametrize(
"arg, expected", [(0, 0), (1, 1), (-1, ValueError)]
)
def test_protocolresult_get_n_protocol_dag_results_args(self, arg, expected):
try:
protocol_result = DummyProtocolResult(
n_protocol_dag_results=arg,
estimate=4.2 * unit.kilojoule_per_mole,
uncertainty=0.2 * unit.kilojoule_per_mole,
)
assert protocol_result.n_protocol_dag_results == expected
except ValueError:
if expected is not ValueError:
raise AssertionError()