From bab5f38486fbc4661ef3684ca8ccd200acd09e30 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 30 Oct 2024 15:02:11 -0700 Subject: [PATCH 1/7] Added the n_protocol_dag_results property to ProtocolResult The `gather` aggregation method removes too much information when reducing the provided protocol_dag_results. This change adds the new n_protocol_dag_results property that reflects the number of ProtocolDAGResult objects that were used to create the ProtocolResult. --- gufe/protocols/protocol.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index bfda10c8..91e3f7b6 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -6,7 +6,7 @@ """ import abc -from typing import Optional, Iterable, Any, Union +from typing import Optional, Iterable, Any, Union, Sized from openff.units import Quantity import warnings @@ -32,8 +32,11 @@ class ProtocolResult(GufeTokenizable): - `get_uncertainty` """ - def __init__(self, **data): + def __init__(self, n_protocol_dag_results: Optional[int] = None, **data): self._data = data + self._n_protocol_dag_results = ( + n_protocol_dag_results if n_protocol_dag_results is not None else 0 + ) @classmethod def _defaults(cls): @@ -46,6 +49,10 @@ def _to_dict(self): def _from_dict(cls, dct: dict): return cls(**dct['data']) + @property + def n_protocol_dag_results(self) -> int: + return self._n_protocol_dag_results + @property def data(self) -> dict[str, Any]: """ @@ -254,7 +261,14 @@ def gather( ProtocolResult Aggregated results from many `ProtocolDAGResult`s from a given `Protocol`. """ - return self.result_cls(**self._gather(protocol_dag_results)) + # Iterable does not implement __len__ and makes no guarantees that + # protocol_dag_results is finite, checking both in method signature + # doesn't appear possible, explicitly check for __len__ through the + # Sized type + if not isinstance(protcol_dag_results, Sized): + raise ValueError("`protocol_dag_results` must implement `__len__`") + return self.result_cls(n_protocol_dag_results=len(protocol_dag_results), + **self._gather(protocol_dag_results)) @abc.abstractmethod def _gather( From d97cf22dcd21f702fe0d46c4eed02db05c8ceb06 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 30 Oct 2024 15:25:01 -0700 Subject: [PATCH 2/7] Modified test and fixed typo --- gufe/protocols/protocol.py | 2 +- gufe/tests/test_protocol.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index 91e3f7b6..bda715d1 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -265,7 +265,7 @@ def gather( # protocol_dag_results is finite, checking both in method signature # doesn't appear possible, explicitly check for __len__ through the # Sized type - if not isinstance(protcol_dag_results, Sized): + if not isinstance(protocol_dag_results, Sized): raise ValueError("`protocol_dag_results` must implement `__len__`") return self.result_cls(n_protocol_dag_results=len(protocol_dag_results), **self._gather(protocol_dag_results)) diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index 8795e665..f558e20f 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -328,6 +328,7 @@ def test_create_execute_gather(self, protocol_dag): # gather aggregated results of interest protocolresult = protocol.gather([dagresult]) + assert protocolresult.n_protocol_dag_results == 1 assert len(protocolresult.data['logs']) == 1 assert len(protocolresult.data['logs'][0]) == 21 + 1 From 2aa451f72ddcc3ece4270dc93ca389fba0314edb Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 30 Oct 2024 15:55:37 -0700 Subject: [PATCH 3/7] Added a test for the infinite generator guardrail --- gufe/tests/test_protocol.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index f558e20f..c343ef93 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -3,7 +3,7 @@ import datetime import itertools from openff.units import unit -from typing import Optional, Iterable, List, Dict, Any, Union +from typing import Optional, Iterable, List, Dict, Any, Union, Sized from collections import defaultdict import pathlib @@ -334,6 +334,25 @@ def test_create_execute_gather(self, protocol_dag): assert protocolresult.get_estimate() == 95500.0 + def test_gather_infinite_iterable_guardrail(self, protocol_dag): + protocol, dag, dagresult = protocol_dag + + assert dagresult.ok() + + # we want an infinite generator, but one that would actually stop early in case + # the guardrail doesn't work, but the type system doesn't know that + def infinite_generator(): + while True: + yield dag + break + + gen = infinite_generator() + assert isinstance(gen, Iterable) + assert not isinstance(gen, Sized) + + with pytest.raises(ValueError, match="`protocol_dag_results` must implement `__len__`"): + protocol.gather(infinite_generator()) + def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand): lig = solvated_ligand.components['ligand'] mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={}) From 3b5ca83ac730d5b68a0406a2830b6f2f979864a4 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 31 Oct 2024 13:51:32 -0700 Subject: [PATCH 4/7] Replaced Optional[int] (default None) with int (default 0) `_to_dict` and `_from_dict` of `ProtocolResult` were also updated to account for the new n_protocol_dag_results property --- gufe/protocols/protocol.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index bda715d1..bf44e7fa 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -10,7 +10,7 @@ from openff.units import Quantity import warnings -from ..settings import Settings, SettingsBaseModel +from ..settings import Settings from ..tokenization import GufeTokenizable, GufeKey from ..chemicalsystem import ChemicalSystem from ..mapping import ComponentMapping @@ -18,7 +18,6 @@ from .protocoldag import ProtocolDAG, ProtocolDAGResult from .protocolunit import ProtocolUnit - class ProtocolResult(GufeTokenizable): """ Container for all results for a single :class:`Transformation`. @@ -32,27 +31,29 @@ class ProtocolResult(GufeTokenizable): - `get_uncertainty` """ - def __init__(self, n_protocol_dag_results: Optional[int] = None, **data): + def __init__(self, n_protocol_dag_results: int = 0, **data): self._data = data - self._n_protocol_dag_results = ( - n_protocol_dag_results if n_protocol_dag_results is not None else 0 - ) + + if not n_protocol_dag_results >= 0: + raise ValueError("`n_protocol_dag_results` must be an integer greater than or equal to zero") + + self._n_protocol_dag_results = n_protocol_dag_results @classmethod def _defaults(cls): return {} def _to_dict(self): - return {'data': self.data} + return {'n_protocol_dag_results': self.n_protocol_dag_results, 'data': self.data} @classmethod def _from_dict(cls, dct: dict): - return cls(**dct['data']) + return cls(n_protocol_dag_results=dct['n_protocol_dag_results'], **dct['data']) @property def n_protocol_dag_results(self) -> int: return self._n_protocol_dag_results - + @property def data(self) -> dict[str, Any]: """ From 1c8bf8cc4ed03bef6a912181bba492f8394c6899 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 31 Oct 2024 13:58:45 -0700 Subject: [PATCH 5/7] Added more tests for the ProtocolResult Explicitly test `get_uncertainty`, `get_estimate`, and `n_protocol_dag_results` --- gufe/tests/test_protocolresult.py | 32 +++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/gufe/tests/test_protocolresult.py b/gufe/tests/test_protocolresult.py index 03833f75..9121b72f 100644 --- a/gufe/tests/test_protocolresult.py +++ b/gufe/tests/test_protocolresult.py @@ -9,16 +9,16 @@ class DummyProtocolResult(gufe.ProtocolResult): def get_estimate(self): - return self.data['estimate'] + return self.data["estimate"] def get_uncertainty(self): - return self.data['uncertainty'] + return self.data["uncertainty"] class TestProtocolResult(GufeTokenizableTestsMixin): cls = DummyProtocolResult - key = 'DummyProtocolResult-b7b854b39c1e37feabec58b4000680a0' - repr = f'<{key}>' + key = "DummyProtocolResult-38c5649090dae613b2570c28155cc5fa" + repr = f"<{key}>" @pytest.fixture def instance(self): @@ -26,3 +26,27 @@ def instance(self): estimate=4.2 * unit.kilojoule_per_mole, uncertainty=0.2 * unit.kilojoule_per_mole, ) + + def test_protocolresult_get_estimate(self, instance): + assert instance.get_estimate() == 4.2 * unit.kilojoule_per_mole + + def test_protocolresult_get_uncertainty(self, instance): + assert instance.get_uncertainty() == 0.2 * unit.kilojoule_per_mole + + def test_protocolresult_default_n_protocol_dag_results(self, instance): + assert instance.n_protocol_dag_results == 0 + + @pytest.mark.parametrize( + "arg, expected", [(0, 0), (1, 1), (-1, ValueError)] + ) + def test_protocolresult_get_n_protocol_dag_results_args(self, arg, expected): + try: + protocol_result = DummyProtocolResult( + n_protocol_dag_results=arg, + estimate=4.2 * unit.kilojoule_per_mole, + uncertainty=0.2 * unit.kilojoule_per_mole, + ) + assert protocol_result.n_protocol_dag_results == expected + except ValueError: + if expected is not ValueError: + raise AssertionError() From 8fe3e24b2a1b205c972098570b90b184af297548 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 20 Nov 2024 20:38:21 -0700 Subject: [PATCH 6/7] Refactor ProtocolResult._from_dict First attempt to extract n_protocol_dag_results from the data dict. If a KeyError is raised, as you'd find for ProtocolResults serialized prior to this commit, set n_protocol_dag_results to the default value of 0. Co-authored-by: Irfan Alibay --- gufe/protocols/protocol.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index bf44e7fa..39816bb7 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -48,7 +48,12 @@ def _to_dict(self): @classmethod def _from_dict(cls, dct: dict): - return cls(n_protocol_dag_results=dct['n_protocol_dag_results'], **dct['data']) + # TODO: remove in gufe 2.0 + try: + n_protocol_dag_results = dct['n_protocol_dag_results'] + except KeyError: + n_protocol_dag_results = 0 + return cls(n_protocol_dag_results=n_protocol_dag_results, **dct['data']) @property def n_protocol_dag_results(self) -> int: From 8d7f3b7f3c11484090ac5710af4ef1ce11e67eb3 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 20 Nov 2024 20:50:55 -0700 Subject: [PATCH 7/7] Add test for creating a ProtocolResult from dict with missing data Test creating a ProtocolResult from a dictionary that is missing the n_protocol_dag_results key. --- gufe/tests/test_protocolresult.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gufe/tests/test_protocolresult.py b/gufe/tests/test_protocolresult.py index a3ccddb3..61f0c292 100644 --- a/gufe/tests/test_protocolresult.py +++ b/gufe/tests/test_protocolresult.py @@ -35,6 +35,12 @@ def test_protocolresult_get_uncertainty(self, instance): def test_protocolresult_default_n_protocol_dag_results(self, instance): assert instance.n_protocol_dag_results == 0 + def test_protocol_result_from_dict_missing_n_protocol_dag_results(self, instance): + protocol_result_dict_form = instance.to_dict() + assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance + del protocol_result_dict_form['n_protocol_dag_results'] + assert DummyProtocolResult.from_dict(protocol_result_dict_form) == instance + @pytest.mark.parametrize( "arg, expected", [(0, 0), (1, 1), (-1, ValueError)] )