Skip to content

Commit

Permalink
Add a characteristic for solvers using action masks and make use of i…
Browse files Browse the repository at this point in the history
…t in rollout (airbus#445)

* Add a characteristic for solvers using action masks

- Use it in rollout to make them be aware of current action mask.
- Add a `get_action_mask()` method to domains by default converting
  applicable actions space into a 0-1 numpy array, provided that the
  action space of each agent is an EnumerableSpace.

* Update how ray.rllib handles action masking

- inherits from Maskable
- do not require anymore FullObservable from the domain to use action
  masking, as get_action_mask() can be called without the solver knowing about
  the current state (and since in rollout, the actual domain is now
  used)
- decide whether using action masking directly in __init__() so that
  using_applicable_actions() can be overriden properly
- use common functions for unwrap_obs and wrap_action in solver and
  wrapper environment to avoid code duplication
- use domain.get_action_mask() to convert applicable actions into a mask
  (the method is more efficient as not calling get_applicable_actions()
  for each actions)

* Use np.int8 instead of np.int64 for action mask dtype

This is more memory sufficient for only 0-1's.
And seems to be the standard for action mask at least for ray.rllib,
as shown in `action_mask_key` documentation at
https://docs.ray.io/en/latest/rllib/rllib-training.html
  • Loading branch information
nhuet authored Dec 16, 2024
1 parent ab49ecb commit 491d3a1
Show file tree
Hide file tree
Showing 8 changed files with 573 additions and 141 deletions.
75 changes: 74 additions & 1 deletion skdecide/builders/domain/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import functools
from typing import Optional, Union

from skdecide.core import D, EmptySpace, Space, autocastable
import numpy as np

from skdecide.core import D, EmptySpace, EnumerableSpace, Mask, Space, autocastable

__all__ = ["Events", "Actions", "UnrestrictedActions"]

Expand Down Expand Up @@ -326,6 +328,77 @@ def _is_applicable_action_from(
else: # StrDict
return all(applicable_actions[k].contains(v) for k, v in action.items())

@autocastable
def get_action_mask(
self, memory: Optional[D.T_memory[D.T_state]] = None
) -> D.T_agent[Mask]:
"""Get action mask for the given memory or internal one if omitted.
An action mask is another (more specific) format for applicable actions, that has a meaning only if the action
space can be iterated over in some way. It is represented by a flat array of 0's and 1's ordered as the actions
when enumerated: 1 for an applicable action, and 0 for a not applicable action.
More precisely, this implementation makes the assumption that each agent action space is an `EnumerableSpace`,
and calls internally `self.get_applicable_action()`.
The action mask is used for instance by RL solvers to shut down logits associated to non-applicable actions in
the output of their internal neural network.
# Parameters
memory: The memory to consider. If None, works on the internal memory of the domain.
# Returns
a numpy array (or dict agent-> numpy array for multi-agent domains) with 0-1 indicating applicability of
the action (1 meaning applicable and 0 not applicable)
"""
return self._get_action_mask(memory=memory)

def _get_action_mask(
self, memory: Optional[D.T_memory[D.T_state]] = None
) -> D.T_agent[Mask]:
"""Get action mask for the given memory or internal one if omitted.
An action mask is another (more specific) format for applicable actions, that has a meaning only if the action
space can be iterated over in some way. It is represented by a flat array of 0's and 1's ordered as the actions
when enumerated: 1 for an applicable action, and 0 for a not applicable action.
More precisely, this implementation makes the assumption that each agent action space is an `EnumerableSpace`,
and calls internally `self.get_applicable_action()`.
The action mask is used for instance by RL solvers to shut down logits associated to non-applicable actions in
the output of their internal neural network.
# Parameters
memory: The memory to consider. If None, works on the internal memory of the domain.
# Returns
a numpy array (or dict agent-> numpy array for multi-agent domains) with 0-1 indicating applicability of
the action (1 meaning applicable and 0 not applicable)
"""
applicable_actions = self._get_applicable_actions(memory=memory)
action_space = self._get_action_space()
if self.T_agent == Union:
# single agent
return np.array(
[
1 if applicable_actions.contains(a) else 0
for a in action_space.get_elements()
],
dtype=np.int8,
)
else:
# multi agent
return {
agent: np.array(
[
1 if agent_applicable_actions.contains(a) else 0
for a in action_space[agent].get_elements()
],
dtype=np.int8,
)
for agent, agent_applicable_actions in applicable_actions.items()
}


class Actions(Events):
"""A domain must inherit this class if it handles only actions (i.e. controllable events)."""
Expand Down
1 change: 1 addition & 0 deletions skdecide/builders/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from skdecide.builders.solver.assessability import *
from skdecide.builders.solver.fromanystatesolvability import *
from skdecide.builders.solver.maskability import *
from skdecide.builders.solver.parallelability import *
from skdecide.builders.solver.policy import *
from skdecide.builders.solver.restorability import *
97 changes: 97 additions & 0 deletions skdecide/builders/solver/maskability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from skdecide import D, autocastable
from skdecide.core import Mask

if TYPE_CHECKING:
# avoid circular import
from skdecide import Domain


__all__ = ["ApplicableActions", "Maskable"]


class ApplicableActions:
"""A solver must inherit this class if he can use information about applicable action.
This characteristic will be checked during rollout so that `retrieve_applicable_actions()` will be called before
each call to `step()`. For instance, this is the case for solvers using action masks (see `Maskable`).
"""

def using_applicable_actions(self):
"""Tell if the solver is able to use applicable actions information.
For instance, action masking could be possible only if
considered domain action space is enumerable for each agent.
The default implementation returns always True.
"""
return True

def retrieve_applicable_actions(self, domain: Domain) -> None:
"""Retrieve applicable actions and use it for future call to `self.step()`.
To be called during rollout to get the actual applicable actions from the actual domain used in rollout.
"""
raise NotImplementedError


class Maskable(ApplicableActions):
"""A solver must inherit this class if he can use action masks to sample actions.
For instance, it can be the case for wrappers around RL solvers like `sb3_contrib.MaskablePPO` or `ray.rllib` with
custom model making use of action masking.
An action mask is a format for specifying applicable actions when the action space is enumerable and finite. It is
an array with 0's (for non-applicable actions) and 1's (for applicable actions). See `Events.get_action_mask()` for
more information.
"""

_action_mask: Optional[D.T_agent[Mask]] = None

def retrieve_applicable_actions(self, domain: Domain) -> None:
"""Retrieve applicable actions and use it for future call to `self.step()`.
To be called during rollout to get the actual applicable actions from the actual domain used in rollout.
Transform applicable actions into an action_mask to be use when sampling action.
"""
self.set_action_mask(domain.get_action_mask())

@autocastable
def set_action_mask(self, action_mask: Optional[D.T_agent[Mask]]) -> None:
"""Set the action mask.
To be called during rollout before `self.sample_action()`, assuming that
`self.sample_action()` knows what to do with it.
Autocastable so that it can use action_mask from original domain during rollout.
"""
self._set_action_mask(action_mask=action_mask)

def _set_action_mask(self, action_mask: Optional[D.T_agent[Mask]]) -> None:
"""Set the action mask.
To be called during rollout before `self.sample_action()`, assuming that
`self.sample_action()` knows what to do with it.
"""

self._action_mask = action_mask

def get_action_mask(self) -> Optional[D.T_agent[Mask]]:
"""Retrieve stored action masks.
To be used by `self.sample_action()`.
Returns None if `self.set_action_mask()` was not called.
"""
return self._action_mask
10 changes: 10 additions & 0 deletions skdecide/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from dataclasses import asdict, astuple, dataclass, replace
from typing import Generic, Optional, TypeVar, Union

import numpy as np
import numpy.typing as npt

__all__ = [
"T",
"D",
Expand Down Expand Up @@ -666,6 +669,13 @@ def cast_evaluate_function(memory, action, next_state):
)


# The following alias is needed in core module so that autocast works:
# - `autocast` does not like "." after strings other than "D",
# - `autocast` needs types in annotations to be evaluable in `skdecide.core` namespace.
Mask = npt.NDArray[np.int8]
"""Alias for single agent action mask."""


SINGLE_AGENT_ID = "agent"

# (auto)cast-related objects/functions
Expand Down
Loading

0 comments on commit 491d3a1

Please sign in to comment.