Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the n_protocol_dag_results property to ProtocolResult #381

Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions gufe/protocols/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
"""

import abc
from typing import Optional, Iterable, Any, Union
from typing import Optional, Iterable, Any, Union, Sized
from openff.units import Quantity
import warnings

from ..settings import Settings, SettingsBaseModel
from ..settings import Settings
from ..tokenization import GufeTokenizable, GufeKey
from ..chemicalsystem import ChemicalSystem
from ..mapping import ComponentMapping

from .protocoldag import ProtocolDAG, ProtocolDAGResult
from .protocolunit import ProtocolUnit


class ProtocolResult(GufeTokenizable):
"""
Container for all results for a single :class:`Transformation`.
Expand All @@ -32,19 +31,33 @@ class ProtocolResult(GufeTokenizable):
- `get_uncertainty`
"""

def __init__(self, **data):
def __init__(self, n_protocol_dag_results: int = 0, **data):
self._data = data

if not n_protocol_dag_results >= 0:
raise ValueError("`n_protocol_dag_results` must be an integer greater than or equal to zero")

self._n_protocol_dag_results = n_protocol_dag_results

@classmethod
def _defaults(cls):
return {}

def _to_dict(self):
return {'data': self.data}
return {'n_protocol_dag_results': self.n_protocol_dag_results, 'data': self.data}

@classmethod
def _from_dict(cls, dct: dict):
return cls(**dct['data'])
# TODO: remove in gufe 2.0
try:
n_protocol_dag_results = dct['n_protocol_dag_results']
except KeyError:
n_protocol_dag_results = 0
return cls(n_protocol_dag_results=n_protocol_dag_results, **dct['data'])

@property
def n_protocol_dag_results(self) -> int:
return self._n_protocol_dag_results

@property
def data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -254,7 +267,14 @@ def gather(
ProtocolResult
Aggregated results from many `ProtocolDAGResult`s from a given `Protocol`.
"""
return self.result_cls(**self._gather(protocol_dag_results))
# Iterable does not implement __len__ and makes no guarantees that
# protocol_dag_results is finite, checking both in method signature
# doesn't appear possible, explicitly check for __len__ through the
# Sized type
if not isinstance(protocol_dag_results, Sized):
raise ValueError("`protocol_dag_results` must implement `__len__`")
return self.result_cls(n_protocol_dag_results=len(protocol_dag_results),
**self._gather(protocol_dag_results))

@abc.abstractmethod
def _gather(
Expand Down
22 changes: 21 additions & 1 deletion gufe/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
import itertools
from openff.units import unit
from typing import Optional, Iterable, List, Dict, Any, Union
from typing import Optional, Iterable, List, Dict, Any, Union, Sized
from collections import defaultdict
import pathlib

Expand Down Expand Up @@ -328,11 +328,31 @@ def test_create_execute_gather(self, protocol_dag):
# gather aggregated results of interest
protocolresult = protocol.gather([dagresult])

assert protocolresult.n_protocol_dag_results == 1
assert len(protocolresult.data['logs']) == 1
assert len(protocolresult.data['logs'][0]) == 21 + 1

assert protocolresult.get_estimate() == 95500.0

def test_gather_infinite_iterable_guardrail(self, protocol_dag):
protocol, dag, dagresult = protocol_dag

assert dagresult.ok()

# we want an infinite generator, but one that would actually stop early in case
# the guardrail doesn't work, but the type system doesn't know that
def infinite_generator():
while True:
yield dag
break

gen = infinite_generator()
assert isinstance(gen, Iterable)
assert not isinstance(gen, Sized)

with pytest.raises(ValueError, match="`protocol_dag_results` must implement `__len__`"):
protocol.gather(infinite_generator())

def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand):
lig = solvated_ligand.components['ligand']
mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})
Expand Down
34 changes: 32 additions & 2 deletions gufe/tests/test_protocolresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

class DummyProtocolResult(gufe.ProtocolResult):
def get_estimate(self):
return self.data['estimate']
return self.data["estimate"]

def get_uncertainty(self):
return self.data['uncertainty']
return self.data["uncertainty"]


class TestProtocolResult(GufeTokenizableTestsMixin):
Expand All @@ -25,3 +25,33 @@ def instance(self):
estimate=4.2 * unit.kilojoule_per_mole,
uncertainty=0.2 * unit.kilojoule_per_mole,
)

def test_protocolresult_get_estimate(self, instance):
assert instance.get_estimate() == 4.2 * unit.kilojoule_per_mole

def test_protocolresult_get_uncertainty(self, instance):
assert instance.get_uncertainty() == 0.2 * unit.kilojoule_per_mole

def test_protocolresult_default_n_protocol_dag_results(self, instance):
assert instance.n_protocol_dag_results == 0

def test_protocol_result_from_dict_missing_n_protocol_dag_results(self, instance):
protocol_result_dict_form = instance.to_dict()
assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance
del protocol_result_dict_form['n_protocol_dag_results']
assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance

@pytest.mark.parametrize(
"arg, expected", [(0, 0), (1, 1), (-1, ValueError)]
)
def test_protocolresult_get_n_protocol_dag_results_args(self, arg, expected):
try:
protocol_result = DummyProtocolResult(
n_protocol_dag_results=arg,
estimate=4.2 * unit.kilojoule_per_mole,
uncertainty=0.2 * unit.kilojoule_per_mole,
)
assert protocol_result.n_protocol_dag_results == expected
except ValueError:
if expected is not ValueError:
raise AssertionError()
Loading