Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the n_protocol_dag_results property to ProtocolResult #381

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions gufe/protocols/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

import abc
from typing import Optional, Iterable, Any, Union
from typing import Optional, Iterable, Any, Union, Sized
from openff.units import Quantity
import warnings

Expand All @@ -32,8 +32,11 @@ class ProtocolResult(GufeTokenizable):
- `get_uncertainty`
"""

def __init__(self, **data):
def __init__(self, n_protocol_dag_results: Optional[int] = None, **data):
self._data = data
self._n_protocol_dag_results = (
n_protocol_dag_results if n_protocol_dag_results is not None else 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to wrap my head around the use case here, could you explain the case where None happens? I.e. at first glance setting None to zero seems incorrect if None means that the gather method did not pass through a value. In that case None means undefined and it's much cleaner than getting 0 with the idea that you got no results (but you do).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the context of a Protocol's gather, you'll never see a None passed in. I can see instances where ProtocolResults are used directly, incorrectly or not, and a None would make sense. But of course I immediately throw away the None in the code you highlighted above. I suppose there is still some indecision on my part what the right signature should be and which values the ProtocolResult holds on to. I'm tempted to just drop Optional[int] and use int with default 0.

Just from the tests, I see instances where someone would make a ProtocolResult directly and I'll probably do something similar in testing parts of stratocaster downstream:

class DummyProtocolResult(gufe.ProtocolResult):
    def get_estimate(self):
        return self.data['estimate']

    def get_uncertainty(self):
        return self.data['uncertainty']


class TestProtocolResult(GufeTokenizableTestsMixin):
    cls = DummyProtocolResult
    key = 'DummyProtocolResult-b7b854b39c1e37feabec58b4000680a0'
    repr = f'<{key}>'

    @pytest.fixture
    def instance(self):
        return DummyProtocolResult(
            estimate=4.2 * unit.kilojoule_per_mole,
            uncertainty=0.2 * unit.kilojoule_per_mole,
        )

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ianmkenney - is this a time critical thing? I would like to spend a bit of time to check what this means for the existing Protocols, but I don't think I'll have time until Monday :(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily. I'll be testing stratocaster strategies based on this branch until it's merged in the hopes that we get some solution that's close in spirit to what we have here.

)

@classmethod
def _defaults(cls):
Expand All @@ -46,6 +49,10 @@ def _to_dict(self):
def _from_dict(cls, dct: dict):
return cls(**dct['data'])

@property
def n_protocol_dag_results(self) -> int:
return self._n_protocol_dag_results

@property
def data(self) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -254,7 +261,14 @@ def gather(
ProtocolResult
Aggregated results from many `ProtocolDAGResult`s from a given `Protocol`.
"""
return self.result_cls(**self._gather(protocol_dag_results))
# Iterable does not implement __len__ and makes no guarantees that
# protocol_dag_results is finite, checking both in method signature
# doesn't appear possible, explicitly check for __len__ through the
# Sized type
if not isinstance(protocol_dag_results, Sized):
raise ValueError("`protocol_dag_results` must implement `__len__`")
return self.result_cls(n_protocol_dag_results=len(protocol_dag_results),
**self._gather(protocol_dag_results))

@abc.abstractmethod
def _gather(
Expand Down
22 changes: 21 additions & 1 deletion gufe/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
import itertools
from openff.units import unit
from typing import Optional, Iterable, List, Dict, Any, Union
from typing import Optional, Iterable, List, Dict, Any, Union, Sized
from collections import defaultdict
import pathlib

Expand Down Expand Up @@ -328,11 +328,31 @@ def test_create_execute_gather(self, protocol_dag):
# gather aggregated results of interest
protocolresult = protocol.gather([dagresult])

assert protocolresult.n_protocol_dag_results == 1
assert len(protocolresult.data['logs']) == 1
assert len(protocolresult.data['logs'][0]) == 21 + 1

assert protocolresult.get_estimate() == 95500.0

def test_gather_infinite_iterable_guardrail(self, protocol_dag):
protocol, dag, dagresult = protocol_dag

assert dagresult.ok()

# we want an infinite generator, but one that would actually stop early in case
# the guardrail doesn't work, but the type system doesn't know that
def infinite_generator():
while True:
yield dag
break

gen = infinite_generator()
assert isinstance(gen, Iterable)
assert not isinstance(gen, Sized)

with pytest.raises(ValueError, match="`protocol_dag_results` must implement `__len__`"):
protocol.gather(infinite_generator())

def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand):
lig = solvated_ligand.components['ligand']
mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})
Expand Down
Loading