forked from airbus/scikit-decide
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a sb3 algo + policy for domains with graph observations (airbus#441)
- we reuse our stable_baselines3 wrapper - the policy is extracting features from the graph with a GNN - the GNN is using pytorch-geometric - We subclass - ActorCriticPolicy: - feature extractor = gnn - custom conversion of observation to torch to convert into torch_geometric.data.Data - PPO to handle properly - observation conversion - rollout buffer - Current limitations: - we extract a fixed number of features (independent of edge/node numbers) for now as we end with a feature reduction layer connected to a classic mlp (not knowning anything about the current graph structure) - User input: the user can define (and default choices are made else) - the gnn (default to a 2 layers GCN), taking as inputs w.r.t torch_geometric conventions: - x: nodes features - edge_index: edge indices or sparse transposed adjency matrix - edge_attr (optional): edges features - edge_weight (optional): edge weights (taken from first dimension of edge_attr) - the feature reduction layer from the gnn output to the fixed number of features (default to global_max_pool + linear layer + relu) We also introduce a multiinput policy to take into account static graph features. The observation space is a DictSpace whose subspaces can contain some Graph spaces.
- Loading branch information
Showing
20 changed files
with
2,471 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv | ||
from gymnasium.spaces import Box, Graph, GraphInstance | ||
|
||
from skdecide.core import Space, TransitionOutcome, Value | ||
from skdecide.domains import Domain | ||
from skdecide.hub.domain.gym import GymDomain | ||
from skdecide.hub.solver.stable_baselines import StableBaseline | ||
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO | ||
from skdecide.hub.space.gym import GymSpace, ListSpace | ||
from skdecide.utils import rollout | ||
|
||
# JSP graph env | ||
|
||
|
||
class D(Domain): | ||
T_state = GraphInstance # Type of states | ||
T_observation = T_state # Type of observations | ||
T_event = int # Type of events | ||
T_value = float # Type of transition values (rewards or costs) | ||
T_info = None # Type of additional information in environment outcome | ||
|
||
|
||
class GraphJspDomain(GymDomain, D): | ||
_gym_env: DisjunctiveGraphJspEnv | ||
|
||
def __init__(self, gym_env): | ||
GymDomain.__init__(self, gym_env=gym_env) | ||
if self._gym_env.normalize_observation_space: | ||
self.n_nodes_features = gym_env.n_machines + 1 | ||
else: | ||
self.n_nodes_features = 2 | ||
|
||
def _state_step( | ||
self, action: D.T_event | ||
) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: | ||
outcome = super()._state_step(action=action) | ||
outcome.state = self._np_state2graph_state(outcome.state) | ||
return outcome | ||
|
||
def _get_applicable_actions_from( | ||
self, memory: D.T_memory[D.T_state] | ||
) -> D.T_agent[Space[D.T_event]]: | ||
return ListSpace(np.nonzero(self._gym_env.valid_action_mask())[0]) | ||
|
||
def _is_applicable_action_from( | ||
self, action: D.T_agent[D.T_event], memory: D.T_memory[D.T_state] | ||
) -> bool: | ||
return self._gym_env.valid_action_mask()[action] | ||
|
||
def _state_reset(self) -> D.T_state: | ||
return self._np_state2graph_state(super()._state_reset()) | ||
|
||
def _get_observation_space_(self) -> Space[D.T_observation]: | ||
if self._gym_env.normalize_observation_space: | ||
original_graph_space = Graph( | ||
node_space=Box( | ||
low=0.0, high=1.0, shape=(self.n_nodes_features,), dtype=np.float_ | ||
), | ||
edge_space=Box(low=0, high=1.0, dtype=np.float_), | ||
) | ||
|
||
else: | ||
original_graph_space = Graph( | ||
node_space=Box( | ||
low=np.array([0, 0]), | ||
high=np.array( | ||
[ | ||
self._gym_env.n_machines, | ||
self._gym_env.longest_processing_time, | ||
] | ||
), | ||
dtype=np.int_, | ||
), | ||
edge_space=Box( | ||
low=0, high=self._gym_env.longest_processing_time, dtype=np.int_ | ||
), | ||
) | ||
return GymSpace(original_graph_space) | ||
|
||
def _np_state2graph_state(self, np_state: np.array) -> GraphInstance: | ||
if not self._gym_env.normalize_observation_space: | ||
np_state = np_state.astype(np.int_) | ||
|
||
nodes = np_state[:, -self.n_nodes_features :] | ||
adj = np_state[:, : -self.n_nodes_features] | ||
edge_starts_ends = adj.nonzero() | ||
edge_links = np.transpose(edge_starts_ends) | ||
edges = adj[edge_starts_ends][:, None] | ||
|
||
return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links) | ||
|
||
def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: | ||
return self._gym_env.render(**kwargs) | ||
|
||
|
||
jsp = np.array( | ||
[ | ||
[ | ||
[0, 1, 2], # machines for job 0 | ||
[0, 2, 1], # machines for job 1 | ||
[0, 1, 2], # machines for job 2 | ||
], | ||
[ | ||
[3, 2, 2], # task durations of job 0 | ||
[2, 1, 4], # task durations of job 1 | ||
[0, 4, 3], # task durations of job 2 | ||
], | ||
] | ||
) | ||
|
||
|
||
jsp_env = DisjunctiveGraphJspEnv( | ||
jps_instance=jsp, | ||
perform_left_shift_if_possible=True, | ||
normalize_observation_space=False, | ||
flat_observation_space=False, | ||
action_mode="task", | ||
) | ||
|
||
# random rollout | ||
domain = GraphJspDomain(gym_env=jsp_env) | ||
rollout(domain=domain, max_steps=jsp_env.total_tasks_without_dummies, num_episodes=1) | ||
|
||
# solve with sb3-GraphPPO | ||
domain_factory = lambda: GraphJspDomain(gym_env=jsp_env) | ||
with StableBaseline( | ||
domain_factory=domain_factory, | ||
algo_class=GraphPPO, | ||
baselines_policy="GraphInputPolicy", | ||
learn_config={"total_timesteps": 100}, | ||
) as solver: | ||
|
||
solver.solve() | ||
rollout(domain=domain_factory(), solver=solver, max_steps=100, num_episodes=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
from typing import Any, Optional | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
from gymnasium.spaces import Box, Discrete, Graph, GraphInstance | ||
|
||
from skdecide.builders.domain import Renderable, UnrestrictedActions | ||
from skdecide.core import Space, Value | ||
from skdecide.domains import DeterministicPlanningDomain | ||
from skdecide.hub.domain.maze import Maze | ||
from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State | ||
from skdecide.hub.solver.stable_baselines import StableBaseline | ||
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO | ||
from skdecide.hub.space.gym import GymSpace, ListSpace | ||
from skdecide.utils import rollout | ||
|
||
|
||
class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable): | ||
T_state = GraphInstance # Type of states | ||
T_observation = T_state # Type of observations | ||
T_event = Action # Type of events | ||
T_value = float # Type of transition values (rewards or costs) | ||
T_predicate = bool # Type of logical checks | ||
T_info = ( | ||
None # Type of additional information given as part of an environment outcome | ||
) | ||
|
||
|
||
class GraphMaze(D): | ||
def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): | ||
self.discrete_features = discrete_features | ||
self.maze_domain = Maze(maze_str=maze_str) | ||
np_wall = np.array(self.maze_domain._maze) | ||
np_y = np.array( | ||
[ | ||
[(i) for j in range(self.maze_domain._num_cols)] | ||
for i in range(self.maze_domain._num_rows) | ||
] | ||
) | ||
np_x = np.array( | ||
[ | ||
[(j) for j in range(self.maze_domain._num_cols)] | ||
for i in range(self.maze_domain._num_rows) | ||
] | ||
) | ||
walls = np_wall.ravel() | ||
coords = [i for i in zip(np_y.ravel(), np_x.ravel())] | ||
np_node_id = np.reshape(range(len(walls)), np_wall.shape) | ||
edge_links = [] | ||
edges = [] | ||
for i in range(self.maze_domain._num_rows): | ||
for j in range(self.maze_domain._num_cols): | ||
current_coord = (i, j) | ||
if i > 0: | ||
next_coord = (i - 1, j) | ||
edge_links.append( | ||
(np_node_id[current_coord], np_node_id[next_coord]) | ||
) | ||
edges.append(np_wall[current_coord] * np_wall[next_coord]) | ||
if i < self.maze_domain._num_rows - 1: | ||
next_coord = (i + 1, j) | ||
edge_links.append( | ||
(np_node_id[current_coord], np_node_id[next_coord]) | ||
) | ||
edges.append(np_wall[current_coord] * np_wall[next_coord]) | ||
if j > 0: | ||
next_coord = (i, j - 1) | ||
edge_links.append( | ||
(np_node_id[current_coord], np_node_id[next_coord]) | ||
) | ||
edges.append(np_wall[current_coord] * np_wall[next_coord]) | ||
if j < self.maze_domain._num_cols - 1: | ||
next_coord = (i, j + 1) | ||
edge_links.append( | ||
(np_node_id[current_coord], np_node_id[next_coord]) | ||
) | ||
edges.append(np_wall[current_coord] * np_wall[next_coord]) | ||
self.edges = np.array(edges) | ||
self.edge_links = np.array(edge_links) | ||
self.walls = walls | ||
self.node_ids = np_node_id | ||
self.coords = coords | ||
|
||
def _mazestate2graph(self, state: State) -> GraphInstance: | ||
x, y = state | ||
agent_presence = np.zeros(self.walls.shape, dtype=self.walls.dtype) | ||
agent_presence[self.node_ids[y, x]] = 1 | ||
nodes = np.stack([self.walls, agent_presence], axis=-1) | ||
if self.discrete_features: | ||
return GraphInstance( | ||
nodes=nodes, edges=self.edges, edge_links=self.edge_links | ||
) | ||
else: | ||
return GraphInstance( | ||
nodes=nodes, edges=self.edges[:, None], edge_links=self.edge_links | ||
) | ||
|
||
def _graph2mazestate(self, graph: GraphInstance) -> State: | ||
y, x = self.coords[graph.nodes[:, 1].nonzero()[0][0]] | ||
return State(x=x, y=y) | ||
|
||
def _is_terminal(self, state: D.T_state) -> D.T_predicate: | ||
return self.maze_domain._is_terminal(self._graph2mazestate(state)) | ||
|
||
def _get_next_state(self, memory: D.T_state, action: D.T_event) -> D.T_state: | ||
maze_memory = self._graph2mazestate(memory) | ||
maze_next_state = self.maze_domain._get_next_state( | ||
memory=maze_memory, action=action | ||
) | ||
return self._mazestate2graph(maze_next_state) | ||
|
||
def _get_transition_value( | ||
self, | ||
memory: D.T_state, | ||
action: D.T_event, | ||
next_state: Optional[D.T_state] = None, | ||
) -> Value[D.T_value]: | ||
maze_memory = self._graph2mazestate(memory) | ||
if next_state is None: | ||
maze_next_state = None | ||
else: | ||
maze_next_state = self._graph2mazestate(next_state) | ||
return self.maze_domain._get_transition_value( | ||
memory=maze_memory, action=action, next_state=maze_next_state | ||
) | ||
|
||
def _get_action_space_(self) -> Space[D.T_event]: | ||
return self.maze_domain._get_action_space_() | ||
|
||
def _get_goals_(self) -> Space[D.T_observation]: | ||
return ListSpace([self._mazestate2graph(self.maze_domain._goal)]) | ||
|
||
def _is_goal( | ||
self, observation: D.T_agent[D.T_observation] | ||
) -> D.T_agent[D.T_predicate]: | ||
return self.maze_domain._is_goal(self._graph2mazestate(observation)) | ||
|
||
def _get_initial_state_(self) -> D.T_state: | ||
return self._mazestate2graph(self.maze_domain._get_initial_state_()) | ||
|
||
def _get_observation_space_(self) -> Space[D.T_observation]: | ||
if self.discrete_features: | ||
return GymSpace( | ||
Graph( | ||
node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), | ||
edge_space=Discrete(2), | ||
) | ||
) | ||
else: | ||
return GymSpace( | ||
Graph( | ||
node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), | ||
edge_space=Box(low=0, high=1, shape=(1,), dtype=self.edges.dtype), | ||
) | ||
) | ||
|
||
def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: | ||
maze_memory = self._graph2mazestate(memory) | ||
self.maze_domain._render_from(memory=maze_memory, **kwargs) | ||
|
||
|
||
MAZE = """ | ||
+-+-+-+-+o+-+-+--+-+-+ | ||
| | | | | ||
+ + + +-+-+-+ +--+ + + | ||
| | | | | | | | | ||
+ +-+-+ +-+ + + -+ +-+ | ||
| | | | | | | | ||
+ + + + + + + +--+ +-+ | ||
| | | | | | | ||
+-+-+-+-+-+-+-+ -+-+ + | ||
| | | | | ||
+ +-+-+-+-+ + +--+-+ + | ||
| | | | | ||
+ + + +-+ +-+ +--+-+-+ | ||
| | | | | | | ||
+ +-+-+ + +-+ + -+-+ + | ||
| | | | | | | | | ||
+-+ +-+ + + + +--+ + + | ||
| | | | | | | | ||
+ +-+ +-+-+-+-+ -+ + + | ||
| | | | | | ||
+-+-+-+-+-+x+-+--+-+-+ | ||
""" | ||
|
||
domain = GraphMaze(maze_str=MAZE, discrete_features=True) | ||
assert domain.reset() in domain.get_observation_space() | ||
|
||
# random rollout | ||
rollout(domain=domain, max_steps=50, num_episodes=1) | ||
|
||
# solve with sb3-PPO-GNN | ||
domain_factory = lambda: GraphMaze(maze_str=MAZE) | ||
max_steps = domain.maze_domain._num_cols * domain.maze_domain._num_rows | ||
with StableBaseline( | ||
domain_factory=domain_factory, | ||
algo_class=GraphPPO, | ||
baselines_policy="GraphInputPolicy", | ||
learn_config={"total_timesteps": 100}, | ||
) as solver: | ||
|
||
solver.solve() | ||
rollout(domain=domain_factory(), solver=solver, max_steps=max_steps, num_episodes=1) |
Oops, something went wrong.