Add option in rollout for sample_action kwargs (e.g. action masking)
- We update the signatures of sample_action(), get_next_action_distribution()
  and get_next_action() with **kwargs so that solvers can accept extra
  arguments such as action masks.
- We add a `kwargs_sample_action_fn` argument to derive generic kwargs
  to pass to `sample_action()` during rollout.
- We add a `use_action_masking` flag as a shortcut that defines an appropriate
  `kwargs_sample_action_fn` using
  - `domain.action_masks()` if it exists,
  - `domain.is_applicable_action()` otherwise, provided that `domain.get_action_space()` is a `skdecide.core.EnumerableSpace`.
nhuet committed Nov 29, 2024
1 parent e5e4a19 commit f8826b1
Showing 32 changed files with 321 additions and 58 deletions.
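For reference, here is a minimal sketch of how the two new rollout options could be used. It assumes that `kwargs_sample_action_fn` takes the current observation and returns the dict of keyword arguments forwarded to `sample_action()`, and that the solver understands an `action_masks` keyword; the exact callback signature should be checked against the rollout changes in this commit.

from skdecide.utils import rollout

def rollout_with_masks(domain, solver):
    # Option 1: derive the kwargs explicitly at each step
    # (assumed callback signature: observation -> dict of kwargs).
    rollout(
        domain,
        solver,
        kwargs_sample_action_fn=lambda obs: {"action_masks": domain.action_masks()},
    )

    # Option 2: shortcut flag; rollout builds the kwargs function itself from
    # domain.action_masks() if available, otherwise from domain.is_applicable_action()
    # (which requires the action space to be an EnumerableSpace).
    rollout(domain, solver, use_action_masking=True)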
2 changes: 1 addition & 1 deletion examples/gym_jsbsim_greedy.py
@@ -139,7 +139,7 @@ def _is_solution_defined_for(self, observation: D.T_agent[D.T_observation]) -> b
         return False  # for to recompute the best action at each step greedily
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         if not self._is_solution_defined_for(observation):
             self._solve_from(observation)
2 changes: 1 addition & 1 deletion examples/gym_jsbsim_iw.py
@@ -275,7 +275,7 @@ def __init__(
         )
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         state = GymDomainStateProxy(
             state=normalize_and_round(observation._state), context=observation._context
2 changes: 1 addition & 1 deletion examples/gym_jsbsim_riw.py
@@ -308,7 +308,7 @@ def __init__(
         )
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         if self._continuous_planning or not self._is_solution_defined_for(observation):
             self._solve_from(observation)
28 changes: 15 additions & 13 deletions skdecide/builders/solver/policy.py
@@ -4,6 +4,8 @@
 
 from __future__ import annotations
 
+from typing import Any
+
 from skdecide.core import D, Distribution, SingleValueDistribution, autocastable
 
 __all__ = ["Policies", "UncertainPolicies", "DeterministicPolicies"]
@@ -14,7 +16,7 @@ class Policies:
 
     @autocastable
     def sample_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Sample an action for the given observation (from the solver's current policy).
@@ -24,10 +26,10 @@ def sample_action(
         # Returns
         The sampled action.
         """
-        return self._sample_action(observation)
+        return self._sample_action(observation, **kwargs)
 
     def _sample_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Sample an action for the given observation (from the solver's current policy).
@@ -68,13 +70,13 @@ class UncertainPolicies(Policies):
     explicitly) as part of the solving process."""
 
     def _sample_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
-        return self._get_next_action_distribution(observation).sample()
+        return self._get_next_action_distribution(observation, **kwargs).sample()
 
     @autocastable
     def get_next_action_distribution(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> Distribution[D.T_agent[D.T_concurrency[D.T_event]]]:
         """Get the probabilistic distribution of next action for the given observation (from the solver's current
         policy).
@@ -85,10 +87,10 @@ def get_next_action_distribution(
         # Returns
         The probabilistic distribution of next action.
         """
-        return self._get_next_action_distribution(observation)
+        return self._get_next_action_distribution(observation, **kwargs)
 
     def _get_next_action_distribution(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> Distribution[D.T_agent[D.T_concurrency[D.T_event]]]:
         """Get the probabilistic distribution of next action for the given observation (from the solver's current
         policy).
@@ -106,13 +108,13 @@ class DeterministicPolicies(UncertainPolicies):
     """A solver must inherit this class if it computes a deterministic policy as part of the solving process."""
 
     def _get_next_action_distribution(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> Distribution[D.T_agent[D.T_concurrency[D.T_event]]]:
-        return SingleValueDistribution(self._get_next_action(observation))
+        return SingleValueDistribution(self._get_next_action(observation, **kwargs))
 
     @autocastable
     def get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the next deterministic action (from the solver's current policy).
@@ -122,10 +124,10 @@ def get_next_action(
         # Returns
         The next deterministic action.
         """
-        return self._get_next_action(observation)
+        return self._get_next_action(observation, **kwargs)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the next deterministic action (from the solver's current policy).
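The change in policy.py above is the core of this commit: the public entry points sample_action(), get_next_action_distribution() and get_next_action() now forward **kwargs down to the solver-specific _sample_action() / _get_next_action() implementations. As a rough illustration of the consumer side, here is a toy deterministic policy that honors an action mask passed this way; the `action_masks` keyword name and its encoding (1 = applicable, 0 = not applicable) are assumptions chosen for this example, not part of the skdecide API.

from typing import Any

import numpy as np

from skdecide.builders.solver.policy import DeterministicPolicies


class MaskAwareGreedyPolicy(DeterministicPolicies):
    """Toy policy: pick the highest-scored action that the mask allows."""

    def __init__(self, scores: np.ndarray, actions: list):
        self._scores = scores    # one score per action (e.g. learned Q-values)
        self._actions = actions  # actions listed in the same order as the scores

    def _get_next_action(self, observation, **kwargs: Any):
        mask = kwargs.get("action_masks")  # assumed keyword, forwarded from sample_action()
        scores = np.array(self._scores, dtype=float)
        if mask is not None:
            scores[np.asarray(mask) == 0] = -np.inf  # rule out non-applicable actions
        return self._actions[int(np.argmax(scores))]

    def _is_policy_defined_for(self, observation) -> bool:
        # Standard Policies hook (name assumed from the skdecide mixin): defined everywhere here.
        return True

With such a solver, a rollout using the new options ends up calling solver.sample_action(observation, action_masks=...), and DeterministicPolicies forwards the keyword down to _get_next_action() exactly as shown in the diff above.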
3 changes: 2 additions & 1 deletion skdecide/hub/solver/aostar/aostar.py
@@ -7,6 +7,7 @@
 import os
 import sys
 from collections.abc import Callable
+from typing import Any
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     FloatHyperparameter,
@@ -175,7 +176,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed action in terms of best Q-value in a given state.
         The solver is run from `observation` if no solution is defined (i.e. has been
3 changes: 2 additions & 1 deletion skdecide/hub/solver/ars/ars.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable, Iterable
+from typing import Any
 
 import gymnasium as gym
 import numpy as np
@@ -259,7 +260,7 @@ def _solve(self) -> None:
         print("Final Reward:", self.reward_evaluation, "Policy", self.policy)
 
     def _sample_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
 
         # print('observation', observation, 'Policy', self.policy)
3 changes: 2 additions & 1 deletion skdecide/hub/solver/astar/astar.py
@@ -7,6 +7,7 @@
 import os
 import sys
 from collections.abc import Callable
+from typing import Any
 
 from skdecide import Domain, Solver, hub
 from skdecide.builders.domain import (
@@ -149,7 +150,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed action in terms of minimum cost-to-go in a given state.
         The solver is run from `observation` if no solution is defined (i.e. has been
2 changes: 1 addition & 1 deletion skdecide/hub/solver/bfws/bfws.py
@@ -159,7 +159,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed action in terms of minimum cost-to-go in a given state.
         The solver is run from `observation` if no solution is defined (i.e. has been
3 changes: 2 additions & 1 deletion skdecide/hub/solver/cgp/cgp.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable, Iterable
+from typing import Any
 
 import gymnasium as gym
 import numpy as np
@@ -348,7 +349,7 @@ def _solve(self) -> None:
         es.run(self._n_it)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
 
         return denorm(
2 changes: 1 addition & 1 deletion skdecide/hub/solver/do_solver/do_solver_scheduling.py
@@ -289,7 +289,7 @@ def compute_external_policy(self, policy_method_params: PolicyMethodParams):
         )
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         return self.policy_object.get_next_action(observation=observation)
 
8 changes: 4 additions & 4 deletions skdecide/hub/solver/do_solver/gphh.py
@@ -667,7 +667,7 @@ def _solve(self) -> None:
         )
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         action = self.policy.sample_action(observation)
         # print('action_1: ', action.action)
@@ -853,7 +853,7 @@ def reset(self):
         pass
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         run_sgs = True
         cheat_mode = False
@@ -1006,7 +1006,7 @@ def reset(self):
         pass
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
 
         run_sgs = True
@@ -1160,7 +1160,7 @@ def reset(self):
         pass
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         run_sgs = True
         cheat_mode = False
4 changes: 2 additions & 2 deletions skdecide/hub/solver/do_solver/sgs_policies.py
@@ -6,7 +6,7 @@
 
 from enum import Enum
 from functools import partial
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     EnumHyperparameter,
@@ -154,7 +154,7 @@ def build_function(self):
         self.func = func
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         return self.func(state=observation)
 
3 changes: 2 additions & 1 deletion skdecide/hub/solver/ilaostar/ilaostar.py
@@ -7,6 +7,7 @@
 import os
 import sys
 from collections.abc import Callable
+from typing import Any
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     FloatHyperparameter,
@@ -167,7 +168,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed action in terms of best Q-value in a given state.
         The solver is run from `observation` if no solution is defined (i.e. has been
2 changes: 1 addition & 1 deletion skdecide/hub/solver/iw/iw.py
@@ -162,7 +162,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed action in terms of minimum cost-to-go in a given state.
         The solver is run from `observation` if no solution is defined (i.e. has been
4 changes: 2 additions & 2 deletions skdecide/hub/solver/lazy_astar/lazy_astar.py
@@ -7,7 +7,7 @@
 from collections.abc import Callable
 from heapq import heappop, heappush
 from itertools import count
-from typing import Optional
+from typing import Any, Optional
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     FloatHyperparameter,
@@ -219,7 +219,7 @@ def extender(node, label, explored):
         # return estim_total, path  # TODO: find a way to expose these things through public API?
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         return self._policy[observation]
 
4 changes: 2 additions & 2 deletions skdecide/hub/solver/lrtastar/lrtastar.py
@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 from collections.abc import Callable
-from typing import Optional
+from typing import Any, Optional
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     FloatHyperparameter,
@@ -59,7 +59,7 @@ class LRTAstar(Solver, DeterministicPolicies, Utilities, FromAnyState):
     ]
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         return self._policy.get(observation, None)
 
4 changes: 2 additions & 2 deletions skdecide/hub/solver/lrtdp/lrtdp.py
@@ -7,7 +7,7 @@
 import os
 import sys
 from collections.abc import Callable
-from typing import Optional
+from typing import Any, Optional
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     CategoricalHyperparameter,
@@ -220,7 +220,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed action in terms of best Q-value in a given state. The search
         subgraph which is no more reachable after executing the returned action is
4 changes: 2 additions & 2 deletions skdecide/hub/solver/mahd/mahd.py
@@ -157,7 +157,7 @@ def _solve_from(self, memory: D.T_memory[D.T_state]) -> None:
         )
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Gets the best computed joint action according to the higher-level heuristic
         multi-agent solver in a given joint state.
@@ -169,7 +169,7 @@ def _get_next_action(
         # Returns
         D.T_agent[D.T_concurrency[D.T_event]]: Best computed joint action
         """
-        return self._multiagent_solver._get_next_action(observation)
+        return self._multiagent_solver._get_next_action(observation, **kwargs)
 
     def _get_utility(self, observation: D.T_agent[D.T_observation]) -> D.T_value:
         """Gets the best value in a given joint state according to the higher-level
3 changes: 2 additions & 1 deletion skdecide/hub/solver/martdp/martdp.py
@@ -7,6 +7,7 @@
 import os
 import sys
 from collections.abc import Callable
+from typing import Any
 
 from discrete_optimization.generic_tools.hyperparameters.hyperparameter import (
     CategoricalHyperparameter,
@@ -204,7 +205,7 @@ def _is_solution_defined_for(
         return self._solver.is_solution_defined_for(observation)
 
     def _get_next_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
         """Get the best computed joint action in terms of best Q-value in a given joint state.
         The search subgraph which is no more reachable after executing the returned action is
3 changes: 2 additions & 1 deletion skdecide/hub/solver/maxent_irl/maxent_irl.py
@@ -6,6 +6,7 @@
 
 import os
 from collections.abc import Callable, Iterable
+from typing import Any
 
 import gymnasium as gym
 import numpy as np
@@ -270,7 +271,7 @@ def _solve(self) -> None:
         )
 
     def _sample_action(
-        self, observation: D.T_agent[D.T_observation]
+        self, observation: D.T_agent[D.T_observation], **kwargs: Any
     ) -> D.T_agent[D.T_concurrency[D.T_event]]:
 
         state_idx = self.index_to_state(