forked from airbus/scikit-decide
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a characteristic for solvers using action masks and make use of i…
…t in rollout (airbus#445) * Add a characteristic for solvers using action masks - Use it in rollout to make them aware of the current action mask. - Add a `get_action_mask()` method to domains, by default converting the applicable actions space into a 0-1 numpy array, provided that the action space of each agent is an EnumerableSpace. * Update how ray.rllib handles action masking - inherits from Maskable - no longer requires FullObservable from the domain to use action masking, as get_action_mask() can be called without the solver knowing about the current state (and since in rollout, the actual domain is now used) - decide whether to use action masking directly in __init__() so that using_applicable_actions() can be overridden properly - use common functions for unwrap_obs and wrap_action in solver and wrapper environment to avoid code duplication - use domain.get_action_mask() to convert applicable actions into a mask (the method is more efficient as it does not call get_applicable_actions() for each action) * Use np.int8 instead of np.int64 for action mask dtype This is more memory efficient for only 0-1's. And it seems to be the standard for action masks, at least for ray.rllib, as shown in the `action_mask_key` documentation at https://docs.ray.io/en/latest/rllib/rllib-training.html
- Loading branch information
Showing
8 changed files
with
573 additions
and
141 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Optional | ||
|
||
from skdecide import D, autocastable | ||
from skdecide.core import Mask | ||
|
||
if TYPE_CHECKING: | ||
# avoid circular import | ||
from skdecide import Domain | ||
|
||
|
||
__all__ = ["ApplicableActions", "Maskable"] | ||
|
||
|
||
class ApplicableActions:
    """Characteristic for solvers that can use information about applicable actions.

    A solver must inherit this class if it can use information about
    applicable actions. This characteristic will be checked during rollout so
    that `retrieve_applicable_actions()` will be called before each call to
    `step()`. For instance, this is the case for solvers using action masks
    (see `Maskable`).
    """

    def using_applicable_actions(self) -> bool:
        """Tell whether the solver is able to use applicable actions information.

        For instance, action masking could be possible only if the considered
        domain action space is enumerable for each agent.

        The default implementation always returns True.
        """
        return True

    def retrieve_applicable_actions(self, domain: Domain) -> None:
        """Retrieve applicable actions and use them for future calls to `self.step()`.

        To be called during rollout to get the actual applicable actions from
        the actual domain used in rollout.

        This base implementation raises `NotImplementedError`: inheriting
        solvers are expected to provide their own implementation.
        """
        raise NotImplementedError
|
||
|
||
class Maskable(ApplicableActions):
    """Characteristic for solvers that can use action masks to sample actions.

    A solver must inherit this class if it can use action masks to sample
    actions. For instance, it can be the case for wrappers around RL solvers
    like `sb3_contrib.MaskablePPO` or `ray.rllib` with a custom model making
    use of action masking.

    An action mask is a format for specifying applicable actions when the
    action space is enumerable and finite. It is an array with 0's (for
    non-applicable actions) and 1's (for applicable actions). See
    `Events.get_action_mask()` for more information.
    """

    # Last action mask stored via set_action_mask(); None until one is set.
    _action_mask: Optional[D.T_agent[Mask]] = None

    def retrieve_applicable_actions(self, domain: Domain) -> None:
        """Retrieve applicable actions and use them for future calls to `self.step()`.

        To be called during rollout to get the actual applicable actions from
        the actual domain used in rollout. The action mask returned by
        `domain.get_action_mask()` is stored through `self.set_action_mask()`,
        to be used when sampling actions.
        """
        self.set_action_mask(domain.get_action_mask())

    @autocastable
    def set_action_mask(self, action_mask: Optional[D.T_agent[Mask]]) -> None:
        """Set the action mask.

        To be called during rollout before `self.sample_action()`, assuming
        that `self.sample_action()` knows what to do with it.

        Autocastable so that it can use the action mask from the original
        domain during rollout. Delegates the actual storage to
        `self._set_action_mask()`.
        """
        self._set_action_mask(action_mask=action_mask)

    def _set_action_mask(self, action_mask: Optional[D.T_agent[Mask]]) -> None:
        """Store the action mask.

        To be called during rollout before `self.sample_action()`, assuming
        that `self.sample_action()` knows what to do with it.
        """
        self._action_mask = action_mask

    def get_action_mask(self) -> Optional[D.T_agent[Mask]]:
        """Retrieve the stored action mask.

        To be used by `self.sample_action()`.

        Returns None if `self.set_action_mask()` was not called.
        """
        return self._action_mask
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.