Skip to content

Commit

Permalink
feat: edit the policy to quarantine faulty tools
Browse files Browse the repository at this point in the history
  • Loading branch information
Adamantios committed Dec 20, 2024
1 parent 4221c57 commit 7e1005b
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def _try_recover_policy(self) -> Optional[EGreedyPolicy]:
def _get_init_policy(self) -> EGreedyPolicy:
    """Get the initial policy.

    :return: the policy recovered via `_try_recover_policy`, or, if that yields nothing,
        a new `EGreedyPolicy` configured from the skill's parameters.
    """
    # try to read the policy from the policy store, and if we cannot recover the policy, we create a new one;
    # a new policy needs the quarantine settings (failures threshold and duration) in addition to epsilon
    return self._try_recover_policy() or EGreedyPolicy(
        self.params.epsilon,
        self.params.policy_threshold,
        self.params.tool_quarantine_duration,
    )

def _fetch_accuracy_info(self) -> Generator[None, None, bool]:
"""Fetch the latest accuracy information available."""
Expand Down
111 changes: 90 additions & 21 deletions packages/valory/skills/decision_maker_abci/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import json
import random
from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Dict, List, Optional, Union
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union

from packages.valory.skills.decision_maker_abci.utils.scaling import scale_value

Expand All @@ -44,8 +45,8 @@ def default(self, o: Any) -> Any:
return super().default(o)


def argmax(li: Union[Tuple, List]) -> int:
    """Get the index of the max value within the provided tuple or list.

    :param li: a non-empty tuple or list of mutually comparable values.
    :return: the index of the first occurrence of the maximum value.
    :raises ValueError: if `li` is empty, because `max` cannot operate on an empty sequence.
    """
    return li.index(max(li))


Expand All @@ -61,6 +62,31 @@ class AccuracyInfo:
accuracy: float = 0.0


@dataclass
class ConsecutiveFailures:
    """Tracker of a tool's current streak of back-to-back failures."""

    # length of the current failure streak
    n_failures: int = 0
    # timestamp passed with the most recent status update
    timestamp: int = 0

    def increase(self, timestamp: int) -> None:
        """Extend the failure streak by one and remember when it happened."""
        self.timestamp = timestamp
        self.n_failures = self.n_failures + 1

    def reset(self, timestamp: int) -> None:
        """Clear the failure streak and remember when it happened."""
        self.timestamp, self.n_failures = timestamp, 0

    def update_status(self, timestamp: int, has_failed: bool) -> None:
        """Record the outcome of a response: a failure extends the streak, a success clears it."""
        record = self.increase if has_failed else self.reset
        record(timestamp)


class EGreedyPolicyDecoder(json.JSONDecoder):
"""A custom JSON decoder for the e greedy policy."""

Expand All @@ -71,17 +97,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
@staticmethod
def hook(
data: Dict[str, Any]
) -> Union["EGreedyPolicy", AccuracyInfo, Dict[str, "EGreedyPolicy"]]:
) -> Union[
"EGreedyPolicy",
AccuracyInfo,
ConsecutiveFailures,
Dict[str, "EGreedyPolicy"],
Dict[str, ConsecutiveFailures],
]:
"""Perform the custom decoding."""
for cls_ in (AccuracyInfo, EGreedyPolicy):
for cls_ in (AccuracyInfo, ConsecutiveFailures, EGreedyPolicy):
cls_attributes = cls_.__annotations__.keys() # pylint: disable=no-member
if sorted(cls_attributes) == sorted(data.keys()) or (
cls_ == EGreedyPolicy
and sorted(cls_attributes - {"updated_ts"}) == sorted(data.keys())
):
# If EGreedyPolicy and 'updated_ts' is missing, set it to 0
if cls_ == EGreedyPolicy and "updated_ts" not in data:
data["updated_ts"] = 0
if sorted(cls_attributes) == sorted(data.keys()):
# if the attributes match the ones of the current class, use it to perform the deserialization
return cls_(**data)

Expand All @@ -93,8 +119,11 @@ class EGreedyPolicy:
"""An e-Greedy policy for the tool selection based on tool accuracy."""

eps: float
consecutive_failures_threshold: int
quarantine_duration: int
accuracy_store: Dict[str, AccuracyInfo] = field(default_factory=dict)
weighted_accuracy: Dict[str, float] = field(default_factory=dict)
consecutive_failures: Dict[str, ConsecutiveFailures] = field(default_factory=dict)
updated_ts: int = 0

def __post_init__(self) -> None:
Expand All @@ -111,7 +140,7 @@ def deserialize(cls, policy: str) -> "EGreedyPolicy":

@property
def tools(self) -> List[str]:
    """Get the policy's tools, i.e., the names under which the accuracy store keeps information."""
    # iterating a dict yields its keys, so no explicit `.keys()` call is needed
    return list(self.accuracy_store)

@property
Expand All @@ -137,12 +166,45 @@ def random_tool(self) -> str:
"""Get the name of a tool randomly."""
return random.choice(list(self.accuracy_store.keys())) # nosec

def is_quarantined(self, tool: str) -> bool:
    """Tell whether the given tool is currently quarantined.

    A tool is quarantined when its number of consecutive failures exceeds the policy's threshold
    and the quarantine duration, counted from the tool's last recorded timestamp, has not elapsed yet.
    Tools without any recorded failures information are never quarantined.
    """
    try:
        record = self.consecutive_failures[tool]
    except KeyError:
        return False

    if record.n_failures <= self.consecutive_failures_threshold:
        return False

    # the quarantine is lifted once the current time (`time()`) reaches this moment
    release_ts = record.timestamp + self.quarantine_duration
    return release_ts > int(time())

@property
def valid_tools(self) -> List[str]:
    """Get the policy's tools that are not quarantined at the moment."""
    return [name for name in self.accuracy_store if not self.is_quarantined(name)]

@property
def valid_weighted_accuracy(self) -> Dict[str, float]:
    """Get the weighted accuracy restricted to the tools that are not quarantined."""
    usable: Dict[str, float] = {}
    for name, score in self.weighted_accuracy.items():
        if self.is_quarantined(name):
            # quarantined tools are left out of the result
            continue
        usable[name] = score
    return usable

@property
def best_tool(self) -> str:
    """Get the best non-quarantined tool.

    The best tool is the one with the highest weighted accuracy; ties are resolved in favor of
    the tool that appears first. If every tool is quarantined, all the tools are considered instead.

    :return: the name of the best tool.
    :raises ValueError: if the policy has no weighted accuracy for any tool.
    """
    # an empty mapping is falsy, so we fall back to all the tools when none is unquarantined;
    # this avoids unpacking `zip(*{}.items())`, which yields nothing and cannot be split into two tuples
    candidates = self.valid_weighted_accuracy or self.weighted_accuracy
    return max(candidates, key=candidates.__getitem__)

def update_weighted_accuracy(self) -> None:
"""Update the weighted accuracy for each tool."""
Expand Down Expand Up @@ -177,6 +239,12 @@ def tool_used(self, tool: str) -> None:
self.accuracy_store[tool].pending += 1
self.update_weighted_accuracy()

def tool_responded(self, tool: str, timestamp: int, failed: bool = True) -> None:
    """Register a tool's response, creating its failures tracker on first sight.

    :param tool: the name of the tool that responded.
    :param timestamp: the timestamp to store along with the updated status.
    :param failed: whether the response counts as a failure.
    """
    try:
        tracker = self.consecutive_failures[tool]
    except KeyError:
        # first response seen for this tool: start with a fresh tracker
        tracker = self.consecutive_failures[tool] = ConsecutiveFailures()
    tracker.update_status(timestamp, failed)

def update_accuracy_store(self, tool: str, winning: bool) -> None:
"""Update the accuracy store for the given tool."""
acc_info = self.accuracy_store[tool]
Expand All @@ -200,11 +268,12 @@ def stats_report(self) -> str:

report = "Policy statistics so far (only for resolved markets):\n"
stats = (
f"{tool} tool:\n"
f"\tTimes used: {self.accuracy_store[tool].requests}\n"
f"\tWeighted Accuracy: {self.weighted_accuracy[tool]}"
f"\t{tool} tool:\n"
f"\t\tQuarantined: {self.is_quarantined(tool)}\n"
f"\t\tTimes used: {self.accuracy_store[tool].requests}\n"
f"\t\tWeighted Accuracy: {self.weighted_accuracy[tool]}"
for tool in self.tools
)
report += "\n".join(stats)
report += f"\nBest tool so far is {self.best_tool!r}."
report += f"\nBest non-quarantined tool so far is {self.best_tool!r}."
return report
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def test_policy_property(
mock_policy_serialized = "serialized_policy_string"
mocked_db.get_strict.return_value = mock_policy_serialized

expected_policy = EGreedyPolicy(eps=0.1)
expected_policy = EGreedyPolicy(
eps=0.1, consecutive_failures_threshold=1, quarantine_duration=0
)
mock_deserialize.return_value = expected_policy

result = sync_data.policy
Expand Down Expand Up @@ -198,7 +200,10 @@ def test_weighted_accuracy(sync_data: SynchronizedData, mocked_db: MagicMock) ->
selected_mech_tool = "tool1"
policy_db_name = "policy"
policy_mock = EGreedyPolicy(
eps=0.1, accuracy_store={selected_mech_tool: AccuracyInfo(requests=1)}
eps=0.1,
consecutive_failures_threshold=1,
quarantine_duration=0,
accuracy_store={selected_mech_tool: AccuracyInfo(requests=1)},
).serialize()
mocked_db.get_strict = lambda name: (
policy_mock if name == policy_db_name else selected_mech_tool
Expand Down

0 comments on commit 7e1005b

Please sign in to comment.