diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c387a269c..a956d1201c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -564,9 +564,9 @@ jobs: python_version=${{ matrix.python-version }} wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*win*.whl) if [ "$python_version" = "3.12" ]; then - pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" + pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" graph-jsp-env pytest-cases else - pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna + pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna graph-jsp-env pytest-cases fi - name: Test with pytest @@ -575,6 +575,8 @@ jobs: ${{ env.minizinc_config_cmdline }} # test minizinc python -c "import minizinc; print(minizinc.default_driver.minizinc_version); minizinc.Solver.lookup('gecode')" + # Set encoding to avoid issue with windows + graph-jsp-env (cf https://github.com/Alexander-Nasuta/graph-jsp-env/issues/3) + export PYTHONIOENCODING=UTF-8 # run pytest # we split tests using # - c++ scikit-decide library @@ -662,9 +664,9 @@ jobs: arch=$(uname -m) wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*macos*${arch}.whl) if [ "$python_version" = "3.12" ]; then - pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" + pip install ${wheelfile}[all] pytest "pygame>=2.5" optuna "cffi>=1.17" graph-jsp-env pytest-cases else - pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna + pip install ${wheelfile}[all] pytest gymnasium[classic-control] optuna graph-jsp-env pytest-cases fi - name: Test with pytest @@ -762,9 +764,9 @@ jobs: python_version=${{ matrix.python-version }} wheelfile=$(ls ./wheels/scikit_decide*-cp${python_version/\./}-*manylinux*.whl) if [ "$python_version" = "3.12" ]; then - pip install ${wheelfile}[all] pytest "pygame>=2.5" "cffi>=1.17" docopt commonmark optuna + pip install ${wheelfile}[all] pytest "pygame>=2.5" "cffi>=1.17" docopt commonmark optuna graph-jsp-env pytest-cases else - pip install ${wheelfile}[all] pytest gymnasium[classic-control] docopt commonmark optuna + pip install ${wheelfile}[all] pytest gymnasium[classic-control] docopt commonmark optuna graph-jsp-env pytest-cases fi - name: Test with pytest diff --git a/examples/gnn_sb3_jsp.py b/examples/gnn_sb3_jsp.py new file mode 100644 index 0000000000..b77b5cdd24 --- /dev/null +++ b/examples/gnn_sb3_jsp.py @@ -0,0 +1,137 @@ +from typing import Any + +import numpy as np +from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv +from gymnasium.spaces import Box, Graph, GraphInstance + +from skdecide.core import Space, TransitionOutcome, Value +from skdecide.domains import Domain +from skdecide.hub.domain.gym import GymDomain +from skdecide.hub.solver.stable_baselines import StableBaseline +from skdecide.hub.solver.stable_baselines.gnn import GraphPPO +from skdecide.hub.space.gym import GymSpace, ListSpace +from skdecide.utils import rollout + +# JSP graph env + + +class D(Domain): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = int # Type of events + T_value = float # Type of transition values (rewards or costs) + T_info = None # Type of additional information in environment outcome + + +class GraphJspDomain(GymDomain, D): + _gym_env: DisjunctiveGraphJspEnv + + def __init__(self, gym_env): + GymDomain.__init__(self, gym_env=gym_env) + if self._gym_env.normalize_observation_space: + self.n_nodes_features = gym_env.n_machines + 1 + else: + self.n_nodes_features = 2 + + def _state_step( + self, action: D.T_event + ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: + outcome = super()._state_step(action=action) + outcome.state = self._np_state2graph_state(outcome.state) + return outcome + + def _get_applicable_actions_from( + self, memory: D.T_memory[D.T_state] + ) -> D.T_agent[Space[D.T_event]]: + return ListSpace(np.nonzero(self._gym_env.valid_action_mask())[0]) + + def _is_applicable_action_from( + self, action: D.T_agent[D.T_event], memory: D.T_memory[D.T_state] + ) -> bool: + return self._gym_env.valid_action_mask()[action] + + def _state_reset(self) -> D.T_state: + return self._np_state2graph_state(super()._state_reset()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + if self._gym_env.normalize_observation_space: + original_graph_space = Graph( + node_space=Box( + low=0.0, high=1.0, shape=(self.n_nodes_features,), dtype=np.float_ + ), + edge_space=Box(low=0, high=1.0, dtype=np.float_), + ) + + else: + original_graph_space = Graph( + node_space=Box( + low=np.array([0, 0]), + high=np.array( + [ + self._gym_env.n_machines, + self._gym_env.longest_processing_time, + ] + ), + dtype=np.int_, + ), + edge_space=Box( + low=0, high=self._gym_env.longest_processing_time, dtype=np.int_ + ), + ) + return GymSpace(original_graph_space) + + def _np_state2graph_state(self, np_state: np.array) -> GraphInstance: + if not self._gym_env.normalize_observation_space: + np_state = np_state.astype(np.int_) + + nodes = np_state[:, -self.n_nodes_features :] + adj = np_state[:, : -self.n_nodes_features] + edge_starts_ends = adj.nonzero() + edge_links = np.transpose(edge_starts_ends) + edges = adj[edge_starts_ends][:, None] + + return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links) + + def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: + return self._gym_env.render(**kwargs) + + +jsp = np.array( + [ + [ + [0, 1, 2], # machines for job 0 + [0, 2, 1], # machines for job 1 + [0, 1, 2], # machines for job 2 + ], + [ + [3, 2, 2], # task durations of job 0 + [2, 1, 4], # task durations of job 1 + [0, 4, 3], # task durations of job 2 + ], + ] +) + + +jsp_env = DisjunctiveGraphJspEnv( + jps_instance=jsp, + perform_left_shift_if_possible=True, + normalize_observation_space=False, + flat_observation_space=False, + action_mode="task", +) + +# random rollout +domain = GraphJspDomain(gym_env=jsp_env) +rollout(domain=domain, max_steps=jsp_env.total_tasks_without_dummies, num_episodes=1) + +# solve with sb3-GraphPPO +domain_factory = lambda: GraphJspDomain(gym_env=jsp_env) +with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, +) as solver: + + solver.solve() + rollout(domain=domain_factory(), solver=solver, max_steps=100, num_episodes=1) diff --git a/examples/gnn_sb3_maze.py b/examples/gnn_sb3_maze.py new file mode 100644 index 0000000000..bc53fb8011 --- /dev/null +++ b/examples/gnn_sb3_maze.py @@ -0,0 +1,203 @@ +from typing import Any, Optional + +import numpy as np +import numpy.typing as npt +from gymnasium.spaces import Box, Discrete, Graph, GraphInstance + +from skdecide.builders.domain import Renderable, UnrestrictedActions +from skdecide.core import Space, Value +from skdecide.domains import DeterministicPlanningDomain +from skdecide.hub.domain.maze import Maze +from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State +from skdecide.hub.solver.stable_baselines import StableBaseline +from skdecide.hub.solver.stable_baselines.gnn import GraphPPO +from skdecide.hub.space.gym import GymSpace, ListSpace +from skdecide.utils import rollout + + +class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class GraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.discrete_features = discrete_features + self.maze_domain = Maze(maze_str=maze_str) + np_wall = np.array(self.maze_domain._maze) + np_y = np.array( + [ + [(i) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + np_x = np.array( + [ + [(j) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + walls = np_wall.ravel() + coords = [i for i in zip(np_y.ravel(), np_x.ravel())] + np_node_id = np.reshape(range(len(walls)), np_wall.shape) + edge_links = [] + edges = [] + for i in range(self.maze_domain._num_rows): + for j in range(self.maze_domain._num_cols): + current_coord = (i, j) + if i > 0: + next_coord = (i - 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if i < self.maze_domain._num_rows - 1: + next_coord = (i + 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j > 0: + next_coord = (i, j - 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j < self.maze_domain._num_cols - 1: + next_coord = (i, j + 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + self.edges = np.array(edges) + self.edge_links = np.array(edge_links) + self.walls = walls + self.node_ids = np_node_id + self.coords = coords + + def _mazestate2graph(self, state: State) -> GraphInstance: + x, y = state + agent_presence = np.zeros(self.walls.shape, dtype=self.walls.dtype) + agent_presence[self.node_ids[y, x]] = 1 + nodes = np.stack([self.walls, agent_presence], axis=-1) + if self.discrete_features: + return GraphInstance( + nodes=nodes, edges=self.edges, edge_links=self.edge_links + ) + else: + return GraphInstance( + nodes=nodes, edges=self.edges[:, None], edge_links=self.edge_links + ) + + def _graph2mazestate(self, graph: GraphInstance) -> State: + y, x = self.coords[graph.nodes[:, 1].nonzero()[0][0]] + return State(x=x, y=y) + + def _is_terminal(self, state: D.T_state) -> D.T_predicate: + return self.maze_domain._is_terminal(self._graph2mazestate(state)) + + def _get_next_state(self, memory: D.T_state, action: D.T_event) -> D.T_state: + maze_memory = self._graph2mazestate(memory) + maze_next_state = self.maze_domain._get_next_state( + memory=maze_memory, action=action + ) + return self._mazestate2graph(maze_next_state) + + def _get_transition_value( + self, + memory: D.T_state, + action: D.T_event, + next_state: Optional[D.T_state] = None, + ) -> Value[D.T_value]: + maze_memory = self._graph2mazestate(memory) + if next_state is None: + maze_next_state = None + else: + maze_next_state = self._graph2mazestate(next_state) + return self.maze_domain._get_transition_value( + memory=maze_memory, action=action, next_state=maze_next_state + ) + + def _get_action_space_(self) -> Space[D.T_event]: + return self.maze_domain._get_action_space_() + + def _get_goals_(self) -> Space[D.T_observation]: + return ListSpace([self._mazestate2graph(self.maze_domain._goal)]) + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.maze_domain._is_goal(self._graph2mazestate(observation)) + + def _get_initial_state_(self) -> D.T_state: + return self._mazestate2graph(self.maze_domain._get_initial_state_()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + if self.discrete_features: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Discrete(2), + ) + ) + else: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Box(low=0, high=1, shape=(1,), dtype=self.edges.dtype), + ) + ) + + def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: + maze_memory = self._graph2mazestate(memory) + self.maze_domain._render_from(memory=maze_memory, **kwargs) + + +MAZE = """ ++-+-+-+-+o+-+-+--+-+-+ +| | | | ++ + + +-+-+-+ +--+ + + +| | | | | | | | ++ +-+-+ +-+ + + -+ +-+ +| | | | | | | ++ + + + + + + +--+ +-+ +| | | | | | ++-+-+-+-+-+-+-+ -+-+ + +| | | | ++ +-+-+-+-+ + +--+-+ + +| | | | ++ + + +-+ +-+ +--+-+-+ +| | | | | | ++ +-+-+ + +-+ + -+-+ + +| | | | | | | | ++-+ +-+ + + + +--+ + + +| | | | | | | ++ +-+ +-+-+-+-+ -+ + + +| | | | | ++-+-+-+-+-+x+-+--+-+-+ +""" + +domain = GraphMaze(maze_str=MAZE, discrete_features=True) +assert domain.reset() in domain.get_observation_space() + +# random rollout +rollout(domain=domain, max_steps=50, num_episodes=1) + +# solve with sb3-PPO-GNN +domain_factory = lambda: GraphMaze(maze_str=MAZE) +max_steps = domain.maze_domain._num_cols * domain.maze_domain._num_rows +with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, +) as solver: + + solver.solve() + rollout(domain=domain_factory(), solver=solver, max_steps=max_steps, num_episodes=1) diff --git a/poetry.lock b/poetry.lock index 4a4644b702..1cd3b25a20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11,6 +11,115 @@ files = [ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, ] +[[package]] +name = "aiohappyeyeballs" +version = "2.4.3" +description = "Happy Eyeballs for asyncio" +optional = true +python-versions = ">=3.8" +files = [ + {file = "aiohappyeyeballs-2.4.3-py3-none-any.whl", hash = "sha256:8a7a83727b2756f394ab2895ea0765a0a8c475e3c71e98d43d76f22b4b435572"}, + {file = "aiohappyeyeballs-2.4.3.tar.gz", hash = "sha256:75cf88a15106a5002a8eb1dab212525c00d1f4c0fa96e551c9fbe6f09a621586"}, +] + +[[package]] +name = "aiohttp" +version = "3.11.6" +description = "Async http client/server framework (asyncio)" +optional = true +python-versions = ">=3.9" +files = [ + {file = "aiohttp-3.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7510b3ca2275691875ddf072a5b6cd129278d11fe09301add7d292fc8d3432de"}, + {file = "aiohttp-3.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfab0d2c3380c588fc925168533edb21d3448ad76c3eadc360ff963019161724"}, + {file = "aiohttp-3.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf02dba0f342f3a8228f43fae256aafc21c4bc85bffcf537ce4582e2b1565188"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92daedf7221392e7a7984915ca1b0481a94c71457c2f82548414a41d65555e70"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2274a7876e03429e3218589a6d3611a194bdce08c3f1e19962e23370b47c0313"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8a2e1eae2d2f62f3660a1591e16e543b2498358593a73b193006fb89ee37abc6"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:978ec3fb0a42efcd98aae608f58c6cfcececaf0a50b4e86ee3ea0d0a574ab73b"}, + {file = "aiohttp-3.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a51f87b27d9219ed4e202ed8d6f1bb96f829e5eeff18db0d52f592af6de6bdbf"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:04d1a02a669d26e833c8099992c17f557e3b2fdb7960a0c455d7b1cbcb05121d"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3679d5fcbc7f1ab518ab4993f12f80afb63933f6afb21b9b272793d398303b98"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a4b24e03d04893b5c8ec9cd5f2f11dc9c8695c4e2416d2ac2ce6c782e4e5ffa5"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:d9abdfd35ecff1c95f270b7606819a0e2de9e06fa86b15d9080de26594cf4c23"}, + {file = "aiohttp-3.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8b5c3e7928a0ad80887a5eba1c1da1830512ddfe7394d805badda45c03db3109"}, + {file = "aiohttp-3.11.6-cp310-cp310-win32.whl", hash = "sha256:913dd9e9378f3c38aeb5c4fb2b8383d6490bc43f3b427ae79f2870651ae08f22"}, + {file = "aiohttp-3.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:4ac26d482c2000c3a59bf757a77adc972828c9d4177b4bd432a46ba682ca7271"}, + {file = "aiohttp-3.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:26ac4c960ea8debf557357a172b3ef201f2236a462aefa1bc17683a75483e518"}, + {file = "aiohttp-3.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8b1f13ebc99fb98c7c13057b748f05224ccc36d17dee18136c695ef23faaf4ff"}, + {file = "aiohttp-3.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4679f1a47516189fab1774f7e45a6c7cac916224c91f5f94676f18d0b64ab134"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74491fdb3d140ff561ea2128cb7af9ba0a360067ee91074af899c9614f88a18f"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f51e1a90412d387e62aa2d243998c5eddb71373b199d811e6ed862a9f34f9758"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72ab89510511c3bb703d0bb5504787b11e0ed8be928ed2a7cf1cda9280628430"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6681c9e046d99646e8059266688374a063da85b2e4c0ebfa078cda414905d080"}, + {file = "aiohttp-3.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a17f8a6d3ab72cbbd137e494d1a23fbd3ea973db39587941f32901bb3c5c350"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:867affc7612a314b95f74d93aac550ce0909bc6f0b6c658cc856890f4d326542"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:00d894ebd609d5a423acef885bd61e7f6a972153f99c5b3ea45fc01fe909196c"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:614c87be9d0d64477d1e4b663bdc5d1534fc0a7ebd23fb08347ab9fd5fe20fd7"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:533ed46cf772f28f3bffae81c0573d916a64dee590b5dfaa3f3d11491da05b95"}, + {file = "aiohttp-3.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:589884cfbc09813afb1454816b45677e983442e146183143f988f7f5a040791a"}, + {file = "aiohttp-3.11.6-cp311-cp311-win32.whl", hash = "sha256:1da63633ba921669eec3d7e080459d4ceb663752b3dafb2f31f18edd248d2170"}, + {file = "aiohttp-3.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:d778ddda09622e7d83095cc8051698a0084c155a1474bfee9bac27d8613dbc31"}, + {file = "aiohttp-3.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:943a952df105a5305257984e7a1f5c2d0fd8564ff33647693c4d07eb2315446d"}, + {file = "aiohttp-3.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d24ec28b7658970a1f1d98608d67f88376c7e503d9d45ff2ba1949c09f2b358c"}, + {file = "aiohttp-3.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6720e809a660fdb9bec7c168c582e11cfedce339af0a5ca847a5d5b588dce826"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4252d30da0ada6e6841b325869c7ef5104b488e8dd57ec439892abbb8d7b3615"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f65f43ff01b238aa0b5c47962c83830a49577efe31bd37c1400c3d11d8a32835"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4dc5933f6c9b26404444d36babb650664f984b8e5fa0694540e7b7315d11a4ff"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bf546ba0c029dfffc718c4b67748687fd4f341b07b7c8f1719d6a3a46164798"}, + {file = "aiohttp-3.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c351d05bbeae30c088009c0bb3b17dda04fd854f91cc6196c448349cc98f71c3"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:10499079b063576fad1597898de3f9c0a2ce617c19cc7cd6b62fdcff6b408bf7"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:442ee82eda47dd59798d6866ce020fb8d02ea31ac9ac82b3d719ed349e6a9d52"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:86fce9127bc317119b34786d9e9ae8af4508a103158828a535f56d201da6ab19"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:973d26a5537ce5d050302eb3cd876457451745b1da0624cbb483217970e12567"}, + {file = "aiohttp-3.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:532b8f038a4e001137d3600cea5d3439d1881df41bdf44d0f9651264d562fdf0"}, + {file = "aiohttp-3.11.6-cp312-cp312-win32.whl", hash = "sha256:4863c59f748dbe147da82b389931f2a676aebc9d3419813ed5ca32d057c9cb32"}, + {file = "aiohttp-3.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:5d7f481f82c18ac1f7986e31ba6eea9be8b2e2c86f1ef035b6866179b6c5dd68"}, + {file = "aiohttp-3.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:40f502350496ba4c6820816d3164f8a0297b9aa4e95d910da31beb189866a9df"}, + {file = "aiohttp-3.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9072669b0bffb40f1f6977d0b5e8a296edc964f9cefca3a18e68649c214d0ce3"}, + {file = "aiohttp-3.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:518160ecf4e6ffd61715bc9173da0925fcce44ae6c7ca3d3f098fe42585370fb"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f69cc1b45115ac44795b63529aa5caa9674be057f11271f65474127b24fc1ce6"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6be90a6beced41653bda34afc891617c6d9e8276eef9c183f029f851f0a3c3d"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00c22fe2486308770d22ef86242101d7b0f1e1093ce178f2358f860e5149a551"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2607ebb783e3aeefa017ec8f34b506a727e6b6ab2c4b037d65f0bc7151f4430a"}, + {file = "aiohttp-3.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f761d6819870c2a8537f75f3e2fc610b163150cefa01f9f623945840f601b2c"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e44d1bc6c88f5234115011842219ba27698a5f2deee245c963b180080572aaa2"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7e0cb6a1b1f499cb2aa0bab1c9f2169ad6913c735b7447e058e0c29c9e51c0b5"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a76b4d4ca34254dca066acff2120811e2a8183997c135fcafa558280f2cc53f3"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:69051c1e45fb18c0ae4d39a075532ff0b015982e7997f19eb5932eb4a3e05c17"}, + {file = "aiohttp-3.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:aff2ed18274c0bfe0c1d772781c87d5ca97ae50f439729007cec9644ee9b15fe"}, + {file = "aiohttp-3.11.6-cp313-cp313-win32.whl", hash = "sha256:2fbea25f2d44df809a46414a8baafa5f179d9dda7e60717f07bded56300589b3"}, + {file = "aiohttp-3.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:f77bc29a465c0f9f6573d1abe656d385fa673e34efe615bd4acc50899280ee47"}, + {file = "aiohttp-3.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:de6123b298d17bca9e53581f50a275b36e10d98e8137eb743ce69ee766dbdfe9"}, + {file = "aiohttp-3.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a10200f705f4fff00e148b7f41e5d1d929c7cd4ac523c659171a0ea8284cd6fb"}, + {file = "aiohttp-3.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b7776ef6901b54dd557128d96c71e412eec0c39ebc07567e405ac98737995aad"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e5c2a55583cd91936baf73d223807bb93ace6eb1fe54424782690f2707162ab"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b032bd6cf7422583bf44f233f4a1489fee53c6d35920123a208adc54e2aba41e"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04fe2d99acbc5cf606f75d7347bf3a027c24c27bc052d470fb156f4cfcea5739"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84a79c366375c2250934d1238abe5d5ea7754c823a1c7df0c52bf0a2bfded6a9"}, + {file = "aiohttp-3.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c33cbbe97dc94a34d1295a7bb68f82727bcbff2b284f73ae7e58ecc05903da97"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:19e4fb9ac727834b003338dcdd27dcfe0de4fb44082b01b34ed0ab67c3469fc9"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a97f6b2afbe1d27220c0c14ea978e09fb4868f462ef3d56d810d206bd2e057a2"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c3f7afeea03a9bc49be6053dfd30809cd442cc12627d6ca08babd1c1f9e04ccf"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:0d10967600ce5bb69ddcb3e18d84b278efb5199d8b24c3c71a4959c2f08acfd0"}, + {file = "aiohttp-3.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:60f2f631b9fe7aa321fa0f0ff3f5d8b9f7f9b72afd4eecef61c33cf1cfea5d58"}, + {file = "aiohttp-3.11.6-cp39-cp39-win32.whl", hash = "sha256:4d2b75333deb5c5f61bac5a48bba3dbc142eebbd3947d98788b6ef9cc48628ae"}, + {file = "aiohttp-3.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:8908c235421972a2e02abcef87d16084aabfe825d14cc9a1debd609b3cfffbea"}, + {file = "aiohttp-3.11.6.tar.gz", hash = "sha256:fd9f55c1b51ae1c20a1afe7216a64a88d38afee063baa23c7fce03757023c999"}, +] + +[package.dependencies] +aiohappyeyeballs = ">=2.3.0" +aiosignal = ">=1.1.2" +async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""} +attrs = ">=17.3.0" +frozenlist = ">=1.1.1" +multidict = ">=4.5,<7.0" +propcache = ">=0.2.0" +yarl = ">=1.17.0,<2.0" + +[package.extras] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] + [[package]] name = "aiosignal" version = "1.3.1" @@ -69,6 +178,17 @@ files = [ six = ">=1.6.1,<2.0" wheel = ">=0.23.0,<1.0" +[[package]] +name = "async-timeout" +version = "5.0.1" +description = "Timeout context manager for asyncio programs" +optional = true +python-versions = ">=3.8" +files = [ + {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, + {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, +] + [[package]] name = "atomicwrites" version = "1.4.1" @@ -496,7 +616,7 @@ cffi = "*" name = "cloudpickle" version = "3.1.0" description = "Pickler class to extend the standard pickle.Pickler functionality" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e"}, @@ -827,6 +947,20 @@ files = [ {file = "debugpy-1.8.7.zip", hash = "sha256:18b8f731ed3e2e1df8e9cdaa23fb1fc9c24e570cd0081625308ec51c82efe42e"}, ] +[[package]] +name = "decopatch" +version = "1.4.10" +description = "Create decorators easily in python." +optional = false +python-versions = "*" +files = [ + {file = "decopatch-1.4.10-py2.py3-none-any.whl", hash = "sha256:e151f7f93de2b1b3fd3f3272dcc7cefd1a69f68ec1c2d8e288ecd9deb36dc5f7"}, + {file = "decopatch-1.4.10.tar.gz", hash = "sha256:957f49c93f4150182c23f8fb51d13bb3213e0f17a79e09c8cca7057598b55720"}, +] + +[package.dependencies] +makefun = ">=1.5.0" + [[package]] name = "decorator" version = "5.1.1" @@ -1080,7 +1214,7 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth name = "farama-notifications" version = "0.0.4" description = "Notifications for all Farama Foundation maintained libraries." -optional = true +optional = false python-versions = "*" files = [ {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, @@ -1365,6 +1499,32 @@ files = [ [package.dependencies] six = "*" +[[package]] +name = "graph-jsp-env" +version = "0.3.3" +description = "A flexible enviorment for job shop scheduling using the disjunctive graph apporach." +optional = false +python-versions = ">=3.9" +files = [ + {file = "graph-jsp-env-0.3.3.tar.gz", hash = "sha256:11f4f628a121237a9a073a65e568473d4c58aa83156e87e6221cff90a980411a"}, + {file = "graph_jsp_env-0.3.3-py3-none-any.whl", hash = "sha256:808f30550d07ba9956cb3985cb3ff7dfd37826a4da34c232c5c3d813173e4eb0"}, +] + +[package.dependencies] +gymnasium = "*" +kaleido = "*" +matplotlib = "*" +networkx = ">=3" +numpy = "*" +opencv-python = "*" +pandas = "*" +plotly = "*" +rich = "*" + +[package.extras] +dev = ["black", "bumpver", "flake8", "isort", "mypy", "pip-tools", "pytest", "pytest-cov", "stable-baselines3"] +testing = ["flake8 (>=3.9)", "mypy (>=0.910)", "pytest (>=6.0)", "pytest-cov (>=2.0)", "tox (>=3.24)"] + [[package]] name = "grpcio" version = "1.67.0" @@ -1468,7 +1628,7 @@ matrixapi = ["numpy", "scipy"] name = "gymnasium" version = "0.28.1" description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "gymnasium-0.28.1-py3-none-any.whl", hash = "sha256:7bc9a5bce1022f997d1dbc152fc91d1ac977bad9cc7794cdc25437010867cabf"}, @@ -1735,8 +1895,8 @@ importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} jaxlib = ">=0.4.27,<=0.4.30" ml-dtypes = ">=0.2.0" numpy = [ - {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.22", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] opt-einsum = "*" @@ -1758,7 +1918,7 @@ tpu = ["jaxlib (==0.4.30)", "libtpu-nightly (==0.1.dev20240617)", "requests"] name = "jax-jumpy" version = "1.0.0" description = "Common backend for Jax or Numpy." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "jax-jumpy-1.0.0.tar.gz", hash = "sha256:195fb955cc4c2b7f0b1453e3cb1fb1c414a51a407ffac7a51e69a73cb30d59ad"}, @@ -1951,6 +2111,16 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] +[[package]] +name = "kaleido" +version = "0.2.1.post1" +description = "Static image export for web-based visualization libraries with zero dependencies" +optional = false +python-versions = "*" +files = [ + {file = "kaleido-0.2.1.post1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:d313940896c24447fc12c74f60d46ea826195fc991f58569a6e73864d53e5c20"}, +] + [[package]] name = "keras" version = "3.6.0" @@ -2213,6 +2383,17 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] flake8 = ["flake8"] tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] +[[package]] +name = "makefun" +version = "1.15.6" +description = "Small library to dynamically create python functions." +optional = false +python-versions = "*" +files = [ + {file = "makefun-1.15.6-py2.py3-none-any.whl", hash = "sha256:e69b870f0bb60304765b1e3db576aaecf2f9b3e5105afe8cfeff8f2afe6ad067"}, + {file = "makefun-1.15.6.tar.gz", hash = "sha256:26bc63442a6182fb75efed8b51741dd2d1db2f176bec8c64e20a586256b8f149"}, +] + [[package]] name = "markdown" version = "3.7" @@ -2235,7 +2416,7 @@ testing = ["coverage", "pyyaml"] name = "markdown-it-py" version = "3.0.0" description = "Python port of markdown-it. Markdown parsing, done right!" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, @@ -2407,7 +2588,7 @@ traitlets = "*" name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, @@ -2456,9 +2637,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, - {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">1.20", markers = "python_version < \"3.10\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -2566,6 +2747,110 @@ files = [ {file = "msgpack-1.1.0.tar.gz", hash = "sha256:dd432ccc2c72b914e4cb77afce64aab761c1137cc698be3984eee260bcb2896e"}, ] +[[package]] +name = "multidict" +version = "6.1.0" +description = "multidict implementation" +optional = true +python-versions = ">=3.8" +files = [ + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a114d03b938376557927ab23f1e950827c3b893ccb94b62fd95d430fd0e5cf53"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1c416351ee6271b2f49b56ad7f308072f6f44b37118d69c2cad94f3fa8a40d5"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b5d83030255983181005e6cfbac1617ce9746b219bc2aad52201ad121226581"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e97b5e938051226dc025ec80980c285b053ffb1e25a3db2a3aa3bc046bf7f56"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d618649d4e70ac6efcbba75be98b26ef5078faad23592f9b51ca492953012429"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10524ebd769727ac77ef2278390fb0068d83f3acb7773792a5080f2b0abf7748"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ff3827aef427c89a25cc96ded1759271a93603aba9fb977a6d264648ebf989db"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06809f4f0f7ab7ea2cabf9caca7d79c22c0758b58a71f9d32943ae13c7ace056"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f179dee3b863ab1c59580ff60f9d99f632f34ccb38bf67a33ec6b3ecadd0fd76"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:aaed8b0562be4a0876ee3b6946f6869b7bcdb571a5d1496683505944e268b160"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c8b88a2ccf5493b6c8da9076fb151ba106960a2df90c2633f342f120751a9e7"}, + {file = "multidict-6.1.0-cp310-cp310-win32.whl", hash = "sha256:4a9cb68166a34117d6646c0023c7b759bf197bee5ad4272f420a0141d7eb03a0"}, + {file = "multidict-6.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:20b9b5fbe0b88d0bdef2012ef7dee867f874b72528cf1d08f1d59b0e3850129d"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3efe2c2cb5763f2f1b275ad2bf7a287d3f7ebbef35648a9726e3b69284a4f3d6"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7053d3b0353a8b9de430a4f4b4268ac9a4fb3481af37dfe49825bf45ca24156"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27e5fc84ccef8dfaabb09d82b7d179c7cf1a3fbc8a966f8274fcb4ab2eb4cadb"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2b90b43e696f25c62656389d32236e049568b39320e2735d51f08fd362761b"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d83a047959d38a7ff552ff94be767b7fd79b831ad1cd9920662db05fec24fe72"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1a9dd711d0877a1ece3d2e4fea11a8e75741ca21954c919406b44e7cf971304"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec2abea24d98246b94913b76a125e855eb5c434f7c46546046372fe60f666351"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4867cafcbc6585e4b678876c489b9273b13e9fff9f6d6d66add5e15d11d926cb"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5b48204e8d955c47c55b72779802b219a39acc3ee3d0116d5080c388970b76e3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8fff389528cad1618fb4b26b95550327495462cd745d879a8c7c2115248e399"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a7a9541cd308eed5e30318430a9c74d2132e9a8cb46b901326272d780bf2d423"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:da1758c76f50c39a2efd5e9859ce7d776317eb1dd34317c8152ac9251fc574a3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c943a53e9186688b45b323602298ab727d8865d8c9ee0b17f8d62d14b56f0753"}, + {file = "multidict-6.1.0-cp311-cp311-win32.whl", hash = "sha256:90f8717cb649eea3504091e640a1b8568faad18bd4b9fcd692853a04475a4b80"}, + {file = "multidict-6.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:82176036e65644a6cc5bd619f65f6f19781e8ec2e5330f51aa9ada7504cc1926"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04772ed465fa3cc947db808fa306d79b43e896beb677a56fb2347ca1a49c1fa"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6180c0ae073bddeb5a97a38c03f30c233e0a4d39cd86166251617d1bbd0af436"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:071120490b47aa997cca00666923a83f02c7fbb44f71cf7f136df753f7fa8761"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50b3a2710631848991d0bf7de077502e8994c804bb805aeb2925a981de58ec2e"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b58c621844d55e71c1b7f7c498ce5aa6985d743a1a59034c57a905b3f153c1ef"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b6d90641869892caa9ca42ff913f7ff1c5ece06474fbd32fb2cf6834726c95"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b820514bfc0b98a30e3d85462084779900347e4d49267f747ff54060cc33925"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a9b09aba0c5b48c53761b7c720aaaf7cf236d5fe394cd399c7ba662d5f9966"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e16bf3e5fc9f44632affb159d30a437bfe286ce9e02754759be5536b169b305"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:76f364861c3bfc98cbbcbd402d83454ed9e01a5224bb3a28bf70002a230f73e2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:820c661588bd01a0aa62a1283f20d2be4281b086f80dad9e955e690c75fb54a2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:0e5f362e895bc5b9e67fe6e4ded2492d8124bdf817827f33c5b46c2fe3ffaca6"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ec660d19bbc671e3a6443325f07263be452c453ac9e512f5eb935e7d4ac28b3"}, + {file = "multidict-6.1.0-cp312-cp312-win32.whl", hash = "sha256:58130ecf8f7b8112cdb841486404f1282b9c86ccb30d3519faf301b2e5659133"}, + {file = "multidict-6.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:188215fc0aafb8e03341995e7c4797860181562380f81ed0a87ff455b70bf1f1"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d569388c381b24671589335a3be6e1d45546c2988c2ebe30fdcada8457a31008"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:052e10d2d37810b99cc170b785945421141bf7bb7d2f8799d431e7db229c385f"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f90c822a402cb865e396a504f9fc8173ef34212a342d92e362ca498cad308e28"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b225d95519a5bf73860323e633a664b0d85ad3d5bede6d30d95b35d4dfe8805b"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:23bfd518810af7de1116313ebd9092cb9aa629beb12f6ed631ad53356ed6b86c"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c09fcfdccdd0b57867577b719c69e347a436b86cd83747f179dbf0cc0d4c1f3"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf6bea52ec97e95560af5ae576bdac3aa3aae0b6758c6efa115236d9e07dae44"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57feec87371dbb3520da6192213c7d6fc892d5589a93db548331954de8248fd2"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0c3f390dc53279cbc8ba976e5f8035eab997829066756d811616b652b00a23a3"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:59bfeae4b25ec05b34f1956eaa1cb38032282cd4dfabc5056d0a1ec4d696d3aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b2f59caeaf7632cc633b5cf6fc449372b83bbdf0da4ae04d5be36118e46cc0aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:37bb93b2178e02b7b618893990941900fd25b6b9ac0fa49931a40aecdf083fe4"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4e9f48f58c2c523d5a06faea47866cd35b32655c46b443f163d08c6d0ddb17d6"}, + {file = "multidict-6.1.0-cp313-cp313-win32.whl", hash = "sha256:3a37ffb35399029b45c6cc33640a92bef403c9fd388acce75cdc88f58bd19a81"}, + {file = "multidict-6.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:e9aa71e15d9d9beaad2c6b9319edcdc0a49a43ef5c0a4c8265ca9ee7d6c67774"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:db7457bac39421addd0c8449933ac32d8042aae84a14911a757ae6ca3eef1392"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d094ddec350a2fb899fec68d8353c78233debde9b7d8b4beeafa70825f1c281a"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5845c1fd4866bb5dd3125d89b90e57ed3138241540897de748cdf19de8a2fca2"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9079dfc6a70abe341f521f78405b8949f96db48da98aeb43f9907f342f627cdc"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3914f5aaa0f36d5d60e8ece6a308ee1c9784cd75ec8151062614657a114c4478"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c08be4f460903e5a9d0f76818db3250f12e9c344e79314d1d570fc69d7f4eae4"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d093be959277cb7dee84b801eb1af388b6ad3ca6a6b6bf1ed7585895789d027d"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3702ea6872c5a2a4eeefa6ffd36b042e9773f05b1f37ae3ef7264b1163c2dcf6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2090f6a85cafc5b2db085124d752757c9d251548cedabe9bd31afe6363e0aff2"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:f67f217af4b1ff66c68a87318012de788dd95fcfeb24cc889011f4e1c7454dfd"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:189f652a87e876098bbc67b4da1049afb5f5dfbaa310dd67c594b01c10388db6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:6bb5992037f7a9eff7991ebe4273ea7f51f1c1c511e6a2ce511d0e7bdb754492"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f4c2b9e770c4e393876e35a7046879d195cd123b4f116d299d442b335bcd"}, + {file = "multidict-6.1.0-cp38-cp38-win32.whl", hash = "sha256:e27bbb6d14416713a8bd7aaa1313c0fc8d44ee48d74497a0ff4c3a1b6ccb5167"}, + {file = "multidict-6.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:22f3105d4fb15c8f57ff3959a58fcab6ce36814486500cd7485651230ad4d4ef"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4e18b656c5e844539d506a0a06432274d7bd52a7487e6828c63a63d69185626c"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a185f876e69897a6f3325c3f19f26a297fa058c5e456bfcff8015e9a27e83ae1"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab7c4ceb38d91570a650dba194e1ca87c2b543488fe9309b4212694174fd539c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e617fb6b0b6953fffd762669610c1c4ffd05632c138d61ac7e14ad187870669c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:16e5f4bf4e603eb1fdd5d8180f1a25f30056f22e55ce51fb3d6ad4ab29f7d96f"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4c035da3f544b1882bac24115f3e2e8760f10a0107614fc9839fd232200b875"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:957cf8e4b6e123a9eea554fa7ebc85674674b713551de587eb318a2df3e00255"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:483a6aea59cb89904e1ceabd2b47368b5600fb7de78a6e4a2c2987b2d256cf30"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:87701f25a2352e5bf7454caa64757642734da9f6b11384c1f9d1a8e699758057"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:682b987361e5fd7a139ed565e30d81fd81e9629acc7d925a205366877d8c8657"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce2186a7df133a9c895dea3331ddc5ddad42cdd0d1ea2f0a51e5d161e4762f28"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9f636b730f7e8cb19feb87094949ba54ee5357440b9658b2a32a5ce4bce53972"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:73eae06aa53af2ea5270cc066dcaf02cc60d2994bbb2c4ef5764949257d10f43"}, + {file = "multidict-6.1.0-cp39-cp39-win32.whl", hash = "sha256:1ca0083e80e791cffc6efce7660ad24af66c8d4079d2a750b29001b53ff59ada"}, + {file = "multidict-6.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:aa466da5b15ccea564bdab9c89175c762bc12825f4659c11227f515cee76fa4a"}, + {file = "multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506"}, + {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} + [[package]] name = "multiprocess" version = "0.70.17" @@ -2945,6 +3230,32 @@ pandas = ">=1.2" pyyaml = ">=5.1" scipy = ">=1.7" +[[package]] +name = "opencv-python" +version = "4.10.0.84" +description = "Wrapper package for OpenCV python bindings." +optional = false +python-versions = ">=3.6" +files = [ + {file = "opencv-python-4.10.0.84.tar.gz", hash = "sha256:72d234e4582e9658ffea8e9cae5b63d488ad06994ef12d81dc303b17472f3526"}, + {file = "opencv_python-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:fc182f8f4cda51b45f01c64e4cbedfc2f00aff799debebc305d8d0210c43f251"}, + {file = "opencv_python-4.10.0.84-cp37-abi3-macosx_12_0_x86_64.whl", hash = "sha256:71e575744f1d23f79741450254660442785f45a0797212852ee5199ef12eed98"}, + {file = "opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09a332b50488e2dda866a6c5573ee192fe3583239fb26ff2f7f9ceb0bc119ea6"}, + {file = "opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ace140fc6d647fbe1c692bcb2abce768973491222c067c131d80957c595b71f"}, + {file = "opencv_python-4.10.0.84-cp37-abi3-win32.whl", hash = "sha256:2db02bb7e50b703f0a2d50c50ced72e95c574e1e5a0bb35a8a86d0b35c98c236"}, + {file = "opencv_python-4.10.0.84-cp37-abi3-win_amd64.whl", hash = "sha256:32dbbd94c26f611dc5cc6979e6b7aa1f55a64d6b463cc1dcd3c95505a63e48fe"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, + {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, +] + [[package]] name = "opt-einsum" version = "3.4.0" @@ -3195,8 +3506,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" @@ -3382,6 +3693,21 @@ docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-a test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] type = ["mypy (>=1.11.2)"] +[[package]] +name = "plotly" +version = "5.24.1" +description = "An open-source, interactive data visualization library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "plotly-5.24.1-py3-none-any.whl", hash = "sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089"}, + {file = "plotly-5.24.1.tar.gz", hash = "sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae"}, +] + +[package.dependencies] +packaging = "*" +tenacity = ">=6.2.0" + [[package]] name = "pluggy" version = "1.5.0" @@ -3447,6 +3773,113 @@ files = [ [package.dependencies] wcwidth = "*" +[[package]] +name = "propcache" +version = "0.2.0" +description = "Accelerated property cache" +optional = true +python-versions = ">=3.8" +files = [ + {file = "propcache-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c5869b8fd70b81835a6f187c5fdbe67917a04d7e52b6e7cc4e5fe39d55c39d58"}, + {file = "propcache-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:952e0d9d07609d9c5be361f33b0d6d650cd2bae393aabb11d9b719364521984b"}, + {file = "propcache-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:33ac8f098df0585c0b53009f039dfd913b38c1d2edafed0cedcc0c32a05aa110"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97e48e8875e6c13909c800fa344cd54cc4b2b0db1d5f911f840458a500fde2c2"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:388f3217649d6d59292b722d940d4d2e1e6a7003259eb835724092a1cca0203a"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f571aea50ba5623c308aa146eb650eebf7dbe0fd8c5d946e28343cb3b5aad577"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3dfafb44f7bb35c0c06eda6b2ab4bfd58f02729e7c4045e179f9a861b07c9850"}, + {file = "propcache-0.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3ebe9a75be7ab0b7da2464a77bb27febcb4fab46a34f9288f39d74833db7f61"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d2f0d0f976985f85dfb5f3d685697ef769faa6b71993b46b295cdbbd6be8cc37"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a3dc1a4b165283bd865e8f8cb5f0c64c05001e0718ed06250d8cac9bec115b48"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9e0f07b42d2a50c7dd2d8675d50f7343d998c64008f1da5fef888396b7f84630"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e63e3e1e0271f374ed489ff5ee73d4b6e7c60710e1f76af5f0e1a6117cd26394"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:56bb5c98f058a41bb58eead194b4db8c05b088c93d94d5161728515bd52b052b"}, + {file = "propcache-0.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7665f04d0c7f26ff8bb534e1c65068409bf4687aa2534faf7104d7182debb336"}, + {file = "propcache-0.2.0-cp310-cp310-win32.whl", hash = "sha256:7cf18abf9764746b9c8704774d8b06714bcb0a63641518a3a89c7f85cc02c2ad"}, + {file = "propcache-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:cfac69017ef97db2438efb854edf24f5a29fd09a536ff3a992b75990720cdc99"}, + {file = "propcache-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:63f13bf09cc3336eb04a837490b8f332e0db41da66995c9fd1ba04552e516354"}, + {file = "propcache-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:608cce1da6f2672a56b24a015b42db4ac612ee709f3d29f27a00c943d9e851de"}, + {file = "propcache-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:466c219deee4536fbc83c08d09115249db301550625c7fef1c5563a584c9bc87"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc2db02409338bf36590aa985a461b2c96fce91f8e7e0f14c50c5fcc4f229016"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6ed8db0a556343d566a5c124ee483ae113acc9a557a807d439bcecc44e7dfbb"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91997d9cb4a325b60d4e3f20967f8eb08dfcb32b22554d5ef78e6fd1dda743a2"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c7dde9e533c0a49d802b4f3f218fa9ad0a1ce21f2c2eb80d5216565202acab4"}, + {file = "propcache-0.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffcad6c564fe6b9b8916c1aefbb37a362deebf9394bd2974e9d84232e3e08504"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:97a58a28bcf63284e8b4d7b460cbee1edaab24634e82059c7b8c09e65284f178"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:945db8ee295d3af9dbdbb698cce9bbc5c59b5c3fe328bbc4387f59a8a35f998d"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:39e104da444a34830751715f45ef9fc537475ba21b7f1f5b0f4d71a3b60d7fe2"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:c5ecca8f9bab618340c8e848d340baf68bcd8ad90a8ecd7a4524a81c1764b3db"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:c436130cc779806bdf5d5fae0d848713105472b8566b75ff70048c47d3961c5b"}, + {file = "propcache-0.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:191db28dc6dcd29d1a3e063c3be0b40688ed76434622c53a284e5427565bbd9b"}, + {file = "propcache-0.2.0-cp311-cp311-win32.whl", hash = "sha256:5f2564ec89058ee7c7989a7b719115bdfe2a2fb8e7a4543b8d1c0cc4cf6478c1"}, + {file = "propcache-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:6e2e54267980349b723cff366d1e29b138b9a60fa376664a157a342689553f71"}, + {file = "propcache-0.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:2ee7606193fb267be4b2e3b32714f2d58cad27217638db98a60f9efb5efeccc2"}, + {file = "propcache-0.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:91ee8fc02ca52e24bcb77b234f22afc03288e1dafbb1f88fe24db308910c4ac7"}, + {file = "propcache-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e900bad2a8456d00a113cad8c13343f3b1f327534e3589acc2219729237a2e8"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f52a68c21363c45297aca15561812d542f8fc683c85201df0bebe209e349f793"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e41d67757ff4fbc8ef2af99b338bfb955010444b92929e9e55a6d4dcc3c4f09"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a64e32f8bd94c105cc27f42d3b658902b5bcc947ece3c8fe7bc1b05982f60e89"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55346705687dbd7ef0d77883ab4f6fabc48232f587925bdaf95219bae072491e"}, + {file = "propcache-0.2.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00181262b17e517df2cd85656fcd6b4e70946fe62cd625b9d74ac9977b64d8d9"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6994984550eaf25dd7fc7bd1b700ff45c894149341725bb4edc67f0ffa94efa4"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:56295eb1e5f3aecd516d91b00cfd8bf3a13991de5a479df9e27dd569ea23959c"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:439e76255daa0f8151d3cb325f6dd4a3e93043e6403e6491813bcaaaa8733887"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f6475a1b2ecb310c98c28d271a30df74f9dd436ee46d09236a6b750a7599ce57"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3444cdba6628accf384e349014084b1cacd866fbb88433cd9d279d90a54e0b23"}, + {file = "propcache-0.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4a9d9b4d0a9b38d1c391bb4ad24aa65f306c6f01b512e10a8a34a2dc5675d348"}, + {file = "propcache-0.2.0-cp312-cp312-win32.whl", hash = "sha256:69d3a98eebae99a420d4b28756c8ce6ea5a29291baf2dc9ff9414b42676f61d5"}, + {file = "propcache-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ad9c9b99b05f163109466638bd30ada1722abb01bbb85c739c50b6dc11f92dc3"}, + {file = "propcache-0.2.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ecddc221a077a8132cf7c747d5352a15ed763b674c0448d811f408bf803d9ad7"}, + {file = "propcache-0.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0e53cb83fdd61cbd67202735e6a6687a7b491c8742dfc39c9e01e80354956763"}, + {file = "propcache-0.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92fe151145a990c22cbccf9ae15cae8ae9eddabfc949a219c9f667877e40853d"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6a21ef516d36909931a2967621eecb256018aeb11fc48656e3257e73e2e247a"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3f88a4095e913f98988f5b338c1d4d5d07dbb0b6bad19892fd447484e483ba6b"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a5b3bb545ead161be780ee85a2b54fdf7092815995661947812dde94a40f6fb"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67aeb72e0f482709991aa91345a831d0b707d16b0257e8ef88a2ad246a7280bf"}, + {file = "propcache-0.2.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c997f8c44ec9b9b0bcbf2d422cc00a1d9b9c681f56efa6ca149a941e5560da2"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2a66df3d4992bc1d725b9aa803e8c5a66c010c65c741ad901e260ece77f58d2f"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:3ebbcf2a07621f29638799828b8d8668c421bfb94c6cb04269130d8de4fb7136"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1235c01ddaa80da8235741e80815ce381c5267f96cc49b1477fdcf8c047ef325"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3947483a381259c06921612550867b37d22e1df6d6d7e8361264b6d037595f44"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d5bed7f9805cc29c780f3aee05de3262ee7ce1f47083cfe9f77471e9d6777e83"}, + {file = "propcache-0.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4a91d44379f45f5e540971d41e4626dacd7f01004826a18cb048e7da7e96544"}, + {file = "propcache-0.2.0-cp313-cp313-win32.whl", hash = "sha256:f902804113e032e2cdf8c71015651c97af6418363bea8d78dc0911d56c335032"}, + {file = "propcache-0.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:8f188cfcc64fb1266f4684206c9de0e80f54622c3f22a910cbd200478aeae61e"}, + {file = "propcache-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:53d1bd3f979ed529f0805dd35ddaca330f80a9a6d90bc0121d2ff398f8ed8861"}, + {file = "propcache-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:83928404adf8fb3d26793665633ea79b7361efa0287dfbd372a7e74311d51ee6"}, + {file = "propcache-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77a86c261679ea5f3896ec060be9dc8e365788248cc1e049632a1be682442063"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:218db2a3c297a3768c11a34812e63b3ac1c3234c3a086def9c0fee50d35add1f"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7735e82e3498c27bcb2d17cb65d62c14f1100b71723b68362872bca7d0913d90"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:20a617c776f520c3875cf4511e0d1db847a076d720714ae35ffe0df3e440be68"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67b69535c870670c9f9b14a75d28baa32221d06f6b6fa6f77a0a13c5a7b0a5b9"}, + {file = "propcache-0.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4569158070180c3855e9c0791c56be3ceeb192defa2cdf6a3f39e54319e56b89"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:db47514ffdbd91ccdc7e6f8407aac4ee94cc871b15b577c1c324236b013ddd04"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:2a60ad3e2553a74168d275a0ef35e8c0a965448ffbc3b300ab3a5bb9956c2162"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:662dd62358bdeaca0aee5761de8727cfd6861432e3bb828dc2a693aa0471a563"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:25a1f88b471b3bc911d18b935ecb7115dff3a192b6fef46f0bfaf71ff4f12418"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:f60f0ac7005b9f5a6091009b09a419ace1610e163fa5deaba5ce3484341840e7"}, + {file = "propcache-0.2.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:74acd6e291f885678631b7ebc85d2d4aec458dd849b8c841b57ef04047833bed"}, + {file = "propcache-0.2.0-cp38-cp38-win32.whl", hash = "sha256:d9b6ddac6408194e934002a69bcaadbc88c10b5f38fb9307779d1c629181815d"}, + {file = "propcache-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:676135dcf3262c9c5081cc8f19ad55c8a64e3f7282a21266d05544450bffc3a5"}, + {file = "propcache-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:25c8d773a62ce0451b020c7b29a35cfbc05de8b291163a7a0f3b7904f27253e6"}, + {file = "propcache-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:375a12d7556d462dc64d70475a9ee5982465fbb3d2b364f16b86ba9135793638"}, + {file = "propcache-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1ec43d76b9677637a89d6ab86e1fef70d739217fefa208c65352ecf0282be957"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f45eec587dafd4b2d41ac189c2156461ebd0c1082d2fe7013571598abb8505d1"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc092ba439d91df90aea38168e11f75c655880c12782facf5cf9c00f3d42b562"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa1076244f54bb76e65e22cb6910365779d5c3d71d1f18b275f1dfc7b0d71b4d"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:682a7c79a2fbf40f5dbb1eb6bfe2cd865376deeac65acf9beb607505dced9e12"}, + {file = "propcache-0.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e40876731f99b6f3c897b66b803c9e1c07a989b366c6b5b475fafd1f7ba3fb8"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:363ea8cd3c5cb6679f1c2f5f1f9669587361c062e4899fce56758efa928728f8"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:140fbf08ab3588b3468932974a9331aff43c0ab8a2ec2c608b6d7d1756dbb6cb"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e70fac33e8b4ac63dfc4c956fd7d85a0b1139adcfc0d964ce288b7c527537fea"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:b33d7a286c0dc1a15f5fc864cc48ae92a846df287ceac2dd499926c3801054a6"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:f6d5749fdd33d90e34c2efb174c7e236829147a2713334d708746e94c4bde40d"}, + {file = "propcache-0.2.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22aa8f2272d81d9317ff5756bb108021a056805ce63dd3630e27d042c8092798"}, + {file = "propcache-0.2.0-cp39-cp39-win32.whl", hash = "sha256:73e4b40ea0eda421b115248d7e79b59214411109a5bc47d0d48e4c73e3b8fcf9"}, + {file = "propcache-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:9517d5e9e0731957468c29dbfd0f976736a0e55afaea843726e887f36fe017df"}, + {file = "propcache-0.2.0-py3-none-any.whl", hash = "sha256:2ccc28197af5313706511fab3a8b66dcd6da067a1331372c82ea1cb74285e036"}, + {file = "propcache-0.2.0.tar.gz", hash = "sha256:df81779732feb9d01e5d513fad0122efb3d53bbc75f61b2a4f29a020bc985e70"}, +] + [[package]] name = "protobuf" version = "5.28.3" @@ -4071,6 +4504,22 @@ toml = "*" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +[[package]] +name = "pytest-cases" +version = "3.8.6" +description = "Separate test code from test cases in pytest." +optional = false +python-versions = "*" +files = [ + {file = "pytest_cases-3.8.6-py2.py3-none-any.whl", hash = "sha256:564c722492ea7e7ec3ac433fd14070180e65866f49fa35bfe938c0d5d9afba67"}, + {file = "pytest_cases-3.8.6.tar.gz", hash = "sha256:5c24e0ab0cb6f8e802a469b7965906a333d3babb874586ebc56f7e2cbe1a7c44"}, +] + +[package.dependencies] +decopatch = "*" +makefun = ">=1.15.1" +packaging = "*" + [[package]] name = "pytest-cov" version = "2.12.1" @@ -4447,7 +4896,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rich" version = "13.9.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -optional = true +optional = false python-versions = ">=3.8.0" files = [ {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, @@ -4932,6 +5381,21 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tensorboard" version = "2.18.0" @@ -5211,6 +5675,35 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.12.0)"] +[[package]] +name = "torch-geometric" +version = "2.6.1" +description = "Graph Neural Network Library for PyTorch" +optional = true +python-versions = ">=3.8" +files = [ + {file = "torch_geometric-2.6.1-py3-none-any.whl", hash = "sha256:8faeb353f9655f7dbec44c5e0b44c721773bdfb279994da96b9b8b12fd30f427"}, + {file = "torch_geometric-2.6.1.tar.gz", hash = "sha256:1f18f9d0fc4d2239d526221e4f22606a4a3895b5d965a9856d27610a3df662c6"}, +] + +[package.dependencies] +aiohttp = "*" +fsspec = "*" +jinja2 = "*" +numpy = "*" +psutil = ">=5.8.0" +pyparsing = "*" +requests = "*" +tqdm = "*" + +[package.extras] +benchmark = ["matplotlib", "networkx", "pandas", "protobuf (<4.21)", "wandb"] +dev = ["ipython", "matplotlib-inline", "pre-commit", "torch_geometric[test]"] +full = ["ase", "captum (<0.7.0)", "graphviz", "h5py", "matplotlib", "networkx", "numba (<0.60.0)", "opt_einsum", "pandas", "pgmpy", "pynndescent", "pytorch-memlab", "rdflib", "rdkit", "scikit-image", "scikit-learn", "scipy", "statsmodels", "sympy", "tabulate", "torch_geometric[graphgym,modelhub]", "torchmetrics", "trimesh"] +graphgym = ["protobuf (<4.21)", "pytorch-lightning (<2.3.0)", "yacs"] +modelhub = ["huggingface_hub"] +test = ["onnx", "onnxruntime", "pytest", "pytest-cov"] + [[package]] name = "tornado" version = "6.4.1" @@ -5550,6 +6043,102 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] +[[package]] +name = "yarl" +version = "1.17.2" +description = "Yet another URL library" +optional = true +python-versions = ">=3.9" +files = [ + {file = "yarl-1.17.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:93771146ef048b34201bfa382c2bf74c524980870bb278e6df515efaf93699ff"}, + {file = "yarl-1.17.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8281db240a1616af2f9c5f71d355057e73a1409c4648c8949901396dc0a3c151"}, + {file = "yarl-1.17.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:170ed4971bf9058582b01a8338605f4d8c849bd88834061e60e83b52d0c76870"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc61b005f6521fcc00ca0d1243559a5850b9dd1e1fe07b891410ee8fe192d0c0"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:871e1b47eec7b6df76b23c642a81db5dd6536cbef26b7e80e7c56c2fd371382e"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3a58a2f2ca7aaf22b265388d40232f453f67a6def7355a840b98c2d547bd037f"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:736bb076f7299c5c55dfef3eb9e96071a795cb08052822c2bb349b06f4cb2e0a"}, + {file = "yarl-1.17.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8fd51299e21da709eabcd5b2dd60e39090804431292daacbee8d3dabe39a6bc0"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:358dc7ddf25e79e1cc8ee16d970c23faee84d532b873519c5036dbb858965795"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:50d866f7b1a3f16f98603e095f24c0eeba25eb508c85a2c5939c8b3870ba2df8"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8b9c4643e7d843a0dca9cd9d610a0876e90a1b2cbc4c5ba7930a0d90baf6903f"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d63123bfd0dce5f91101e77c8a5427c3872501acece8c90df457b486bc1acd47"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:4e76381be3d8ff96a4e6c77815653063e87555981329cf8f85e5be5abf449021"}, + {file = "yarl-1.17.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:734144cd2bd633a1516948e477ff6c835041c0536cef1d5b9a823ae29899665b"}, + {file = "yarl-1.17.2-cp310-cp310-win32.whl", hash = "sha256:26bfb6226e0c157af5da16d2d62258f1ac578d2899130a50433ffee4a5dfa673"}, + {file = "yarl-1.17.2-cp310-cp310-win_amd64.whl", hash = "sha256:76499469dcc24759399accd85ec27f237d52dec300daaca46a5352fcbebb1071"}, + {file = "yarl-1.17.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:792155279dc093839e43f85ff7b9b6493a8eaa0af1f94f1f9c6e8f4de8c63500"}, + {file = "yarl-1.17.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:38bc4ed5cae853409cb193c87c86cd0bc8d3a70fd2268a9807217b9176093ac6"}, + {file = "yarl-1.17.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4a8c83f6fcdc327783bdc737e8e45b2e909b7bd108c4da1892d3bc59c04a6d84"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c6d5fed96f0646bfdf698b0a1cebf32b8aae6892d1bec0c5d2d6e2df44e1e2d"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:782ca9c58f5c491c7afa55518542b2b005caedaf4685ec814fadfcee51f02493"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff6af03cac0d1a4c3c19e5dcc4c05252411bf44ccaa2485e20d0a7c77892ab6e"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a3f47930fbbed0f6377639503848134c4aa25426b08778d641491131351c2c8"}, + {file = "yarl-1.17.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1fa68a3c921365c5745b4bd3af6221ae1f0ea1bf04b69e94eda60e57958907f"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:187df91395c11e9f9dc69b38d12406df85aa5865f1766a47907b1cc9855b6303"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:93d1c8cc5bf5df401015c5e2a3ce75a5254a9839e5039c881365d2a9dcfc6dc2"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:11d86c6145ac5c706c53d484784cf504d7d10fa407cb73b9d20f09ff986059ef"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:c42774d1d1508ec48c3ed29e7b110e33f5e74a20957ea16197dbcce8be6b52ba"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0c8e589379ef0407b10bed16cc26e7392ef8f86961a706ade0a22309a45414d7"}, + {file = "yarl-1.17.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1056cadd5e850a1c026f28e0704ab0a94daaa8f887ece8dfed30f88befb87bb0"}, + {file = "yarl-1.17.2-cp311-cp311-win32.whl", hash = "sha256:be4c7b1c49d9917c6e95258d3d07f43cfba2c69a6929816e77daf322aaba6628"}, + {file = "yarl-1.17.2-cp311-cp311-win_amd64.whl", hash = "sha256:ac8eda86cc75859093e9ce390d423aba968f50cf0e481e6c7d7d63f90bae5c9c"}, + {file = "yarl-1.17.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:dd90238d3a77a0e07d4d6ffdebc0c21a9787c5953a508a2231b5f191455f31e9"}, + {file = "yarl-1.17.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c74f0b0472ac40b04e6d28532f55cac8090e34c3e81f118d12843e6df14d0909"}, + {file = "yarl-1.17.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d486ddcaca8c68455aa01cf53d28d413fb41a35afc9f6594a730c9779545876"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25b7e93f5414b9a983e1a6c1820142c13e1782cc9ed354c25e933aebe97fcf2"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3a0baff7827a632204060f48dca9e63fbd6a5a0b8790c1a2adfb25dc2c9c0d50"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:460024cacfc3246cc4d9f47a7fc860e4fcea7d1dc651e1256510d8c3c9c7cde0"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5870d620b23b956f72bafed6a0ba9a62edb5f2ef78a8849b7615bd9433384171"}, + {file = "yarl-1.17.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2941756754a10e799e5b87e2319bbec481ed0957421fba0e7b9fb1c11e40509f"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9611b83810a74a46be88847e0ea616794c406dbcb4e25405e52bff8f4bee2d0a"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:cd7e35818d2328b679a13268d9ea505c85cd773572ebb7a0da7ccbca77b6a52e"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6b981316fcd940f085f646b822c2ff2b8b813cbd61281acad229ea3cbaabeb6b"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:688058e89f512fb7541cb85c2f149c292d3fa22f981d5a5453b40c5da49eb9e8"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:56afb44a12b0864d17b597210d63a5b88915d680f6484d8d202ed68ade38673d"}, + {file = "yarl-1.17.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:17931dfbb84ae18b287279c1f92b76a3abcd9a49cd69b92e946035cff06bcd20"}, + {file = "yarl-1.17.2-cp312-cp312-win32.whl", hash = "sha256:ff8d95e06546c3a8c188f68040e9d0360feb67ba8498baf018918f669f7bc39b"}, + {file = "yarl-1.17.2-cp312-cp312-win_amd64.whl", hash = "sha256:4c840cc11163d3c01a9d8aad227683c48cd3e5be5a785921bcc2a8b4b758c4f3"}, + {file = "yarl-1.17.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3294f787a437cb5d81846de3a6697f0c35ecff37a932d73b1fe62490bef69211"}, + {file = "yarl-1.17.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f1e7fedb09c059efee2533119666ca7e1a2610072076926fa028c2ba5dfeb78c"}, + {file = "yarl-1.17.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:da9d3061e61e5ae3f753654813bc1cd1c70e02fb72cf871bd6daf78443e9e2b1"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91c012dceadc695ccf69301bfdccd1fc4472ad714fe2dd3c5ab4d2046afddf29"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f11fd61d72d93ac23718d393d2a64469af40be2116b24da0a4ca6922df26807e"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46c465ad06971abcf46dd532f77560181387b4eea59084434bdff97524444032"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef6eee1a61638d29cd7c85f7fd3ac7b22b4c0fabc8fd00a712b727a3e73b0685"}, + {file = "yarl-1.17.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4434b739a8a101a837caeaa0137e0e38cb4ea561f39cb8960f3b1e7f4967a3fc"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:752485cbbb50c1e20908450ff4f94217acba9358ebdce0d8106510859d6eb19a"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:17791acaa0c0f89323c57da7b9a79f2174e26d5debbc8c02d84ebd80c2b7bff8"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5c6ea72fe619fee5e6b5d4040a451d45d8175f560b11b3d3e044cd24b2720526"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:db5ac3871ed76340210fe028f535392f097fb31b875354bcb69162bba2632ef4"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:7a1606ba68e311576bcb1672b2a1543417e7e0aa4c85e9e718ba6466952476c0"}, + {file = "yarl-1.17.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9bc27dd5cfdbe3dc7f381b05e6260ca6da41931a6e582267d5ca540270afeeb2"}, + {file = "yarl-1.17.2-cp313-cp313-win32.whl", hash = "sha256:52492b87d5877ec405542f43cd3da80bdcb2d0c2fbc73236526e5f2c28e6db28"}, + {file = "yarl-1.17.2-cp313-cp313-win_amd64.whl", hash = "sha256:8e1bf59e035534ba4077f5361d8d5d9194149f9ed4f823d1ee29ef3e8964ace3"}, + {file = "yarl-1.17.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c556fbc6820b6e2cda1ca675c5fa5589cf188f8da6b33e9fc05b002e603e44fa"}, + {file = "yarl-1.17.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f2f44a4247461965fed18b2573f3a9eb5e2c3cad225201ee858726cde610daca"}, + {file = "yarl-1.17.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3a3ede8c248f36b60227eb777eac1dbc2f1022dc4d741b177c4379ca8e75571a"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2654caaf5584449d49c94a6b382b3cb4a246c090e72453493ea168b931206a4d"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0d41c684f286ce41fa05ab6af70f32d6da1b6f0457459a56cf9e393c1c0b2217"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2270d590997445a0dc29afa92e5534bfea76ba3aea026289e811bf9ed4b65a7f"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18662443c6c3707e2fc7fad184b4dc32dd428710bbe72e1bce7fe1988d4aa654"}, + {file = "yarl-1.17.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75ac158560dec3ed72f6d604c81090ec44529cfb8169b05ae6fcb3e986b325d9"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1fee66b32e79264f428dc8da18396ad59cc48eef3c9c13844adec890cd339db5"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:585ce7cd97be8f538345de47b279b879e091c8b86d9dbc6d98a96a7ad78876a3"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:c019abc2eca67dfa4d8fb72ba924871d764ec3c92b86d5b53b405ad3d6aa56b0"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c6e659b9a24d145e271c2faf3fa6dd1fcb3e5d3f4e17273d9e0350b6ab0fe6e2"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:d17832ba39374134c10e82d137e372b5f7478c4cceeb19d02ae3e3d1daed8721"}, + {file = "yarl-1.17.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:bc3003710e335e3f842ae3fd78efa55f11a863a89a72e9a07da214db3bf7e1f8"}, + {file = "yarl-1.17.2-cp39-cp39-win32.whl", hash = "sha256:f5ffc6b7ace5b22d9e73b2a4c7305740a339fbd55301d52735f73e21d9eb3130"}, + {file = "yarl-1.17.2-cp39-cp39-win_amd64.whl", hash = "sha256:48e424347a45568413deec6f6ee2d720de2cc0385019bedf44cd93e8638aa0ed"}, + {file = "yarl-1.17.2-py3-none-any.whl", hash = "sha256:dd7abf4f717e33b7487121faf23560b3a50924f80e4bef62b22dab441ded8f3b"}, + {file = "yarl-1.17.2.tar.gz", hash = "sha256:753eaaa0c7195244c84b5cc159dc8204b7fd99f716f11198f999f2332a86b178"}, +] + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" +propcache = ">=0.2.0" + [[package]] name = "zipp" version = "3.20.2" @@ -5570,11 +6159,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "pyRDDLGym-rl", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "ray", "rddlrepository", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] +all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "pyRDDLGym-rl", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "ray", "rddlrepository", "scipy", "stable-baselines3", "torch-geometric", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] domains = ["cartopy", "gymnasium", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym", "pyRDDLGym-rl", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "rddlrepository", "scipy", "unified-planning"] -solvers = ["gymnasium", "joblib", "numpy", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] +solvers = ["gymnasium", "joblib", "numpy", "pyRDDLGym-gurobi", "pyRDDLGym-jax", "ray", "scipy", "stable-baselines3", "torch-geometric", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "353bf8a69a6c318e17e71f400e98ea6d4392883291de54e4c4570cea0a9288f2" +content-hash = "2b17ad02ae15987e4983858da614b1d71e27e1ef9ff9ed1e4b95d32e1ed02813" diff --git a/pyproject.toml b/pyproject.toml index 8fcad3a225..44b62ba7b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ pyRDDLGym-rl = [ pyRDDLGym-jax = { version = ">=0.3", optional = true } pyRDDLGym-gurobi = { version = ">=0.2", optional = true } rddlrepository = {version = ">=2.0", optional = true } +torch-geometric = {version = ">=2.5", optional = true} [tool.poetry.extras] domains = [ @@ -111,7 +112,8 @@ solvers = [ "up-pyperplan", "scipy", "pyRDDLGym-jax", - "pyRDDLGym-gurobi" + "pyRDDLGym-gurobi", + "torch-geometric" ] all = [ "gymnasium", @@ -134,7 +136,8 @@ all = [ "pyRDDLGym-rl", "rddlrepository", "pyRDDLGym-jax", - "pyRDDLGym-gurobi" + "pyRDDLGym-gurobi", + "torch-geometric" ] [tool.poetry.plugins."skdecide.domains"] @@ -199,6 +202,8 @@ commonmark = ">=0.9.1" gymnasium = { version = ">=0.28.1", extras = [ "classic-control", ], optional = true } +graph-jsp-env = { version = ">=0.3.3"} +pytest-cases = {version = ">=3.8"} [tool.pytest.ini_options] minversion = "6.0" diff --git a/skdecide/hub/domain/gym/gym.py b/skdecide/hub/domain/gym/gym.py index 8cae2bea2a..9285bbbcbe 100644 --- a/skdecide/hub/domain/gym/gym.py +++ b/skdecide/hub/domain/gym/gym.py @@ -1175,6 +1175,10 @@ def __init__(self, domain: Domain, unwrap_spaces: bool = True) -> None: domain.get_action_space() ) # assumes all actions are always applicable + @property + def domain(self) -> Domain: + return self._domain + def step(self, action): """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` to reset this environment's state. @@ -1280,6 +1284,8 @@ def unwrapped(self): class AsGymnasiumEnv(EnvCompatibility): """This class wraps a scikit-decide domain as a gymnasium environment.""" + env: AsLegacyGymV21Env + def __init__( self, domain: Domain, @@ -1288,3 +1294,7 @@ def __init__( ) -> None: legacy_env = AsLegacyGymV21Env(domain=domain, unwrap_spaces=unwrap_spaces) super().__init__(old_env=legacy_env, render_mode=render_mode) + + @property + def domain(self) -> Domain: + return self.env.domain diff --git a/skdecide/hub/solver/stable_baselines/gnn/__init__.py b/skdecide/hub/solver/stable_baselines/gnn/__init__.py new file mode 100644 index 0000000000..261a27e78b --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/__init__.py @@ -0,0 +1 @@ +from .ppo import GraphPPO diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/__init__.py b/skdecide/hub/solver/stable_baselines/gnn/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py b/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py new file mode 100644 index 0000000000..3af10e3998 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/buffers.py @@ -0,0 +1,253 @@ +from collections.abc import Generator +from typing import Optional, TypeVar, Union + +import numpy as np +import torch as th +import torch_geometric as thg +from gymnasium import spaces +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.type_aliases import RolloutBufferSamples +from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env import VecNormalize + +from .utils import copy_graph_instance, graph_obs_to_thg_data + + +def get_obs_shape( + observation_space: spaces.Space, +) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]: + """ + Get the shape of the observation (useful for the buffers). + + :param observation_space: + :return: + """ + if isinstance(observation_space, spaces.Box): + return observation_space.shape + elif isinstance(observation_space, spaces.Discrete): + # Observation is an int + return (1,) + elif isinstance(observation_space, spaces.MultiDiscrete): + # Number of discrete features + return (int(len(observation_space.nvec)),) + elif isinstance(observation_space, spaces.MultiBinary): + # Number of binary features + return observation_space.shape + elif isinstance(observation_space, spaces.Graph): + # Will not be used + return observation_space.node_space.shape + observation_space.edge_space.shape + elif isinstance(observation_space, spaces.Dict): + return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] + + else: + raise NotImplementedError( + f"{observation_space} observation space is not supported" + ) + + +class GraphRolloutBuffer(RolloutBuffer): + """Rollout buffer used in on-policy algorithms like A2C/PPO with graph observations. + + Handles cases where observation space is: + - a Graph space + - a Dict space whose subspaces includes a Graph space + + """ + + observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]] + tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + def __init__( + self, + buffer_size: int, + observation_space: Union[spaces.Graph, spaces.Dict], + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) + self.action_dim = get_action_dim(action_space) + self.pos = 0 + self.full = False + self.device = get_device(device) + self.n_envs = n_envs + self.gae_lambda = gae_lambda + self.gamma = gamma + self.generator_ready = False + + self.reset() + + def reset(self) -> None: + assert isinstance( + self.observation_space, spaces.Graph + ), "GraphRolloutBuffer must be used with Graph obs space only" + super().reset() + self.observations = list() + + def add( + self, + obs: spaces.GraphInstance, + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + ) -> None: + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + self._add_obs(obs) + + # Same reshape, for actions + action = action.reshape((self.n_envs, self.action_dim)) + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def _add_obs(self, obs: list[spaces.GraphInstance]) -> None: + self.observations.append([copy_graph_instance(g) for g in obs]) + + def _swap_and_flatten_obs(self) -> None: + self.observations = _swap_and_flatten_nested_list(self.observations) + + def get( + self, batch_size: Optional[int] = None + ) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + self._swap_and_flatten_obs() + for tensor in self.tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples( + self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None + ) -> RolloutBufferSamples: + observations = self._get_observations_samples(batch_inds) + data = ( + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return RolloutBufferSamples(observations, *tuple(map(self.to_torch, data))) + + def _get_observations_samples(self, batch_inds: np.ndarray) -> thg.data.Data: + return self._graphlist_to_torch(self.observations, batch_inds=batch_inds) + + def _graphlist_to_torch( + self, graph_list: list[spaces.GraphInstance], batch_inds: np.ndarray + ) -> thg.data.Data: + return thg.data.Batch.from_data_list( + [ + graph_obs_to_thg_data(graph_list[idx], device=self.device) + for idx in batch_inds + ] + ) + + +class DictGraphRolloutBuffer(GraphRolloutBuffer, DictRolloutBuffer): + + observations: dict[ + str, + Union[ + Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]], + np.ndarray, + ], + ] + obs_shape: dict[str, tuple[int, ...]] + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Dict, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.is_observation_subspace_graph: dict[str, bool] = { + k: isinstance(space, spaces.Graph) + for k, space in observation_space.spaces.items() + } + super().__init__( + buffer_size=buffer_size, + observation_space=observation_space, + action_space=action_space, + device=device, + gae_lambda=gae_lambda, + gamma=gamma, + n_envs=n_envs, + ) + + def reset(self) -> None: + super(GraphRolloutBuffer, self).reset() + for k, is_graph in self.is_observation_subspace_graph.items(): + if is_graph: + self.observations[k] = list() + + def _add_obs( + self, obs: dict[str, Union[np.ndarray, list[spaces.GraphInstance]]] + ) -> None: + for key in self.observations.keys(): + if self.is_observation_subspace_graph[key]: + self.observations[key].append( + [copy_graph_instance(g) for g in obs[key]] + ) + else: + obs_ = np.array(obs[key]) + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + def _swap_and_flatten_obs(self) -> None: + for key, obs in self.observations.items(): + if self.is_observation_subspace_graph[key]: + self.observations[key] = _swap_and_flatten_nested_list(obs) + else: + self.observations[key] = self.swap_and_flatten(obs) + + def _get_observations_samples( + self, batch_inds: np.ndarray + ) -> dict[str, Union[thg.data.Data, th.Tensor]]: + return { + k: self._graphlist_to_torch(obs, batch_inds=batch_inds) + if self.is_observation_subspace_graph[k] + else self.to_torch(obs[batch_inds]) + for k, obs in self.observations.items() + } + + +T = TypeVar("T") + + +def _swap_and_flatten_nested_list(obs: list[list[T]]) -> list[T]: + return [x for subobs in obs for x in subobs] diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py b/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py new file mode 100644 index 0000000000..eea7743c5b --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/on_policy_algorithm.py @@ -0,0 +1,170 @@ +from typing import Optional, Union + +import numpy as np +import torch as th +from gymnasium import spaces +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv +from stable_baselines3.common.vec_env import VecEnv + +from .buffers import DictGraphRolloutBuffer, GraphRolloutBuffer +from .utils import obs_as_tensor +from .vec_env.dummy_vec_env import wrap_graph_env + + +class GraphOnPolicyAlgorithm(OnPolicyAlgorithm): + """Base class for On-Policy algorithms (ex: A2C/PPO) with graph observations.""" + + def __init__( + self, + policy: Union[str, type[ActorCriticPolicy]], + env: GymEnv, + rollout_buffer_class: Optional[type[RolloutBuffer]] = None, + **kwargs, + ): + + # Use proper default rollout buffer class + if rollout_buffer_class is None: + if isinstance(env.observation_space, spaces.Graph): + rollout_buffer_class = GraphRolloutBuffer + elif isinstance(env.observation_space, spaces.Dict): + rollout_buffer_class = DictGraphRolloutBuffer + + # Use proper VecEnv wrapper for env with Graph spaces + env = wrap_graph_env(env) + if env.num_envs > 1: + raise NotImplementedError( + "GraphOnPolicyAlgorithm not implemented for real vectorized environment " + "(ie. with more than 1 wrapped environment)" + ) + + super().__init__( + policy=policy, + env=env, + rollout_buffer_class=rollout_buffer_class, + **kwargs, + ) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + This method is largely identical to the implementation found in the parent class. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_rollout_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + while n_steps < n_rollout_steps: + if ( + self.use_sde + and self.sde_sample_freq > 0 + and n_steps % self.sde_sample_freq == 0 + ): + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + + actions, values, log_probs = self.policy(obs_tensor) + actions = actions.cpu().numpy() + + # Rescale and perform action + clipped_actions = actions + + if isinstance(self.action_space, spaces.Box): + if self.policy.squash_output: + # Unscale the actions to match env bounds + # if they were previously squashed (scaled in [-1, 1]) + clipped_actions = self.policy.unscale_action(clipped_actions) + else: + # Otherwise, clip the actions to avoid out of bound error + # as we are sampling from an unbounded Gaussian distribution + clipped_actions = np.clip( + actions, self.action_space.low, self.action_space.high + ) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if not callback.on_step(): + return False + + self._update_info_buffer(infos, dones) + n_steps += 1 + + if isinstance(self.action_space, spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor( + infos[idx]["terminal_observation"] + )[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs, + ) + self._last_obs = new_obs # type: ignore[assignment] + self._last_episode_starts = dones + + with th.no_grad(): + # Compute value for the last timestep + obs_tensor = obs_as_tensor(new_obs, self.device) + values = self.policy.predict_values(obs_tensor) # type: ignore[arg-type] + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.update_locals(locals()) + + callback.on_rollout_end() + + return True diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/policies.py b/skdecide/hub/solver/stable_baselines/gnn/common/policies.py new file mode 100644 index 0000000000..a7fb5a6f48 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/policies.py @@ -0,0 +1,204 @@ +import copy +import warnings +from typing import Any, Dict, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np +import torch as th +import torch_geometric as thg +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy +from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from stable_baselines3.common.type_aliases import Schedule + +from .preprocessing import preprocess_obs +from .torch_layers import CombinedFeaturesExtractor, GraphFeaturesExtractor +from .utils import ObsType, TorchObsType, is_vectorized_observation, obs_as_tensor + + +class GNNActorCriticPolicy(ActorCriticPolicy): + def __init__( + self, + observation_space: gym.spaces.Graph, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[list[Union[int, dict[str, list[int]]]]] = None, + activation_fn: type[th.nn.Module] = th.nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: type[BaseFeaturesExtractor] = GraphFeaturesExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__( + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) + + def extract_features( + self, + obs: thg.data.Data, + features_extractor: Optional[BaseFeaturesExtractor] = None, + ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: + """ + Preprocess the observation if needed and extract features. + + :param obs: Observation + :param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used. + :return: The extracted features. If features extractor is not shared, returns a tuple with the + features for the actor and the features for the critic. + """ + preprocessed_obs = preprocess_obs( + obs, self.observation_space, normalize_images=self.normalize_images + ) + if self.share_features_extractor: + if features_extractor is None: + features_extractor = self.features_extractor + return features_extractor(preprocessed_obs) + else: + if features_extractor is not None: + warnings.warn( + "Provided features_extractor will be ignored because the features extractor is not shared.", + UserWarning, + ) + + pi_features = self.pi_features_extractor(preprocessed_obs) + vf_features = self.vf_features_extractor(preprocessed_obs) + return pi_features, vf_features + + def obs_to_tensor(self, observation: ObsType) -> tuple[TorchObsType, bool]: + vectorized_env = False + if isinstance(self.observation_space, gym.spaces.Graph): + vectorized_env = is_vectorized_observation( + observation, self.observation_space + ) + elif isinstance(observation, dict): + assert isinstance( + self.observation_space, gym.spaces.Dict + ), f"The observation provided is a dict but the obs space is {self.observation_space}" + # need to copy the dict as the dict in VecFrameStack will become a torch tensor + observation = copy.deepcopy(observation) + for key, obs in observation.items(): + obs_space = self.observation_space.spaces[key] + if isinstance(obs_space, gym.spaces.Graph): + vectorized_env = vectorized_env or is_vectorized_observation( + obs, obs_space + ) + else: + if is_image_space(obs_space): + obs_ = maybe_transpose(obs, obs_space) + else: + obs_ = np.array(obs) + vectorized_env = vectorized_env or is_vectorized_observation( + obs_, obs_space + ) + # Add batch dimension if needed + observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc] + else: + return super().obs_to_tensor(observation) + + obs_tensor = obs_as_tensor(observation, self.device) + return obs_tensor, vectorized_env + + def is_vectorized_observation( + self, observation: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> bool: + vectorized_env = False + if isinstance(observation, dict): + assert isinstance( + self.observation_space, gym.spaces.Dict + ), f"The observation provided is a dict but the obs space is {self.observation_space}" + for key, obs in observation.items(): + obs_space = self.observation_space.spaces[key] + vectorized_env = vectorized_env or is_vectorized_observation( + maybe_transpose(obs, obs_space), obs_space + ) + else: + vectorized_env = is_vectorized_observation( + maybe_transpose(observation, self.observation_space), + self.observation_space, + ) + return vectorized_env + + def get_distribution(self, obs: thg.data.Data) -> Distribution: + preprocessed_obs = preprocess_obs( + obs, self.observation_space, normalize_images=self.normalize_images + ) + features = self.pi_features_extractor(preprocessed_obs) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values(self, obs: thg.data.Data) -> th.Tensor: + preprocessed_obs = preprocess_obs( + obs, self.observation_space, normalize_images=self.normalize_images + ) + features = self.vf_features_extractor(preprocessed_obs) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + + +class MultiInputGNNActorCriticPolicy(GNNActorCriticPolicy): + def __init__( + self, + observation_space: gym.spaces.Graph, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[list[Union[int, dict[str, list[int]]]]] = None, + activation_fn: type[th.nn.Module] = th.nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: type[ + BaseFeaturesExtractor + ] = CombinedFeaturesExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__( + observation_space=observation_space, + action_space=action_space, + lr_schedule=lr_schedule, + net_arch=net_arch, + activation_fn=activation_fn, + ortho_init=ortho_init, + use_sde=use_sde, + log_std_init=log_std_init, + full_std=full_std, + use_expln=use_expln, + squash_output=squash_output, + features_extractor_class=features_extractor_class, + features_extractor_kwargs=features_extractor_kwargs, + share_features_extractor=share_features_extractor, + normalize_images=normalize_images, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/preprocessing.py b/skdecide/hub/solver/stable_baselines/gnn/common/preprocessing.py new file mode 100644 index 0000000000..dd38c3e7d2 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/preprocessing.py @@ -0,0 +1,38 @@ +import torch as th +import torch_geometric as thg +from gymnasium import spaces +from stable_baselines3.common.preprocessing import preprocess_obs as sb3_preprocess_obs + +from .utils import TorchObsType + + +def preprocess_obs( + obs: TorchObsType, + observation_space: spaces.Space, + normalize_images: bool = True, +) -> TorchObsType: + """Preprocess observation to be fed to a neural network. + + Wraps original sb3 preprocess_obs to catch graph obs. + + """ + if isinstance(observation_space, spaces.Dict): + # Do not modify by reference the original observation + assert isinstance(obs, dict), f"Expected dict, got {type(obs)}" + preprocessed_obs = {} + for key, _obs in obs.items(): + preprocessed_obs[key] = preprocess_obs( + _obs, observation_space[key], normalize_images=normalize_images + ) + return preprocessed_obs # type: ignore[return-value] + + assert isinstance( + obs, (th.Tensor, thg.data.Data) + ), f"Expecting a torch Tensor or torch geometric Data, but got {type(obs)}" + + if isinstance(observation_space, spaces.Graph): + return obs + else: + return sb3_preprocess_obs( + obs, observation_space=observation_space, normalize_images=normalize_images + ) diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py b/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py new file mode 100644 index 0000000000..18aa0d7612 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py @@ -0,0 +1,172 @@ +from typing import Any, Optional, Union + +import gymnasium as gym +import numpy as np +import torch as th +import torch_geometric as thg +from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, NatureCNN +from torch import nn +from torch_geometric.nn import global_max_pool + + +class GraphFeaturesExtractor(BaseFeaturesExtractor): + """Graph feature extractor for Graph observation spaces. + + Will chain a gnn with a reduction layer to extract a fixed number of features. + The user can specify both the gnn and reduction layer. + + By default, we use: + - gnn: a 2-layers GCN + - reduction layer: global_max_pool + linear layer + relu + + Args: + observation_space: + features_dim: Number of extracted features + - If reduction_layer_class is given, should match the output of this network. + - If reduction_layer is None, will be used by the default network as its output dimension. + gnn_out_dim: dimension of the node embedding in gnn output + - If gnn is given, should not be None and should match the output of gnn + - If gnn is not given, will be used to generate it. By default, gnn_out_dim = 2 * features_dim + gnn_class: GNN network class (for instance chosen from `torch_geometric.nn.models` used to embed the graph observations) + gnn_kwargs: used by `gnn_class.__init__()`. Without effect if `gnn_class` is None. + reduction_layer_class: network class to be plugged after the gnn to get a fixed number of features. + reduction_layer_kwargs: used by `reduction_layer_class.__init__()`. Without effect if `reduction_layer_class` is None. + + """ + + def __init__( + self, + observation_space: gym.spaces.Graph, + features_dim: int = 64, + gnn_out_dim: Optional[int] = None, + gnn_class: Optional[type[nn.Module]] = None, + gnn_kwargs: Optional[dict[str, Any]] = None, + reduction_layer_class: Optional[type[nn.Module]] = None, + reduction_layer_kwargs: Optional[dict[str, Any]] = None, + ): + + super().__init__(observation_space, features_dim=features_dim) + + if gnn_out_dim is None: + if gnn_class is None: + gnn_out_dim = 2 * features_dim + else: + raise ValueError( + "`gnn_out_dim` cannot be None if `gnn` is not None, " + "and should match `gnn` output." + ) + + if gnn_class is None: + node_features_dim = int(np.prod(observation_space.node_space.shape)) + self.gnn = thg.nn.models.GCN( + in_channels=node_features_dim, + hidden_channels=gnn_out_dim, + num_layers=2, + dropout=0.2, + ) + else: + if gnn_kwargs is None: + gnn_kwargs = {} + self.gnn = gnn_class(**gnn_kwargs) + + if reduction_layer_class is None: + self.reduction_layer = _DefaultReductionLayer( + gnn_out_dim=gnn_out_dim, features_dim=features_dim + ) + else: + if reduction_layer_kwargs is None: + reduction_layer_kwargs = {} + self.reduction_layer = reduction_layer_class(**reduction_layer_kwargs) + + def forward(self, observations: thg.data.Data) -> th.Tensor: + x, edge_index, edge_attr, batch = ( + observations.x, + observations.edge_index, + observations.edge_attr, + observations.batch, + ) + # construct edge weights, for GNNs needing it, as the first edge feature + edge_weight = edge_attr[:, 0] + h = self.gnn( + x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr + ) + embedded_observations = thg.data.Data( + x=h, edge_index=edge_index, edge_attr=edge_attr, batch=batch + ) + h = self.reduction_layer(embedded_observations=embedded_observations) + return h + + +class _DefaultReductionLayer(nn.Module): + def __init__(self, gnn_out_dim: int, features_dim: int): + super().__init__() + self.gnn_out_dim = gnn_out_dim + self.features_dim = features_dim + self.linear_layer = nn.Linear(gnn_out_dim, features_dim) + + def forward(self, embedded_observations: thg.data.Data) -> th.Tensor: + x, edge_index, batch = ( + embedded_observations.x, + embedded_observations.edge_index, + embedded_observations.batch, + ) + h = global_max_pool(x, batch) + h = self.linear_layer(h).relu() + return h + + +class CombinedFeaturesExtractor(BaseFeaturesExtractor): + """Combined features extractor for Dict observation spaces, subspaces potentially including Graph spaces. + + Builds a features extractor for each key of the space. Input from each space + is fed through a separate submodule (CNN or MLP, depending on input shape), + the output features are concatenated. + + Args: + observation_space: + cnn_kwargs: to be passed to NatureCNN extractor. + `cnn_kwargs["normalized_image"] is used to check if the space is an image space + (see `stable_baselines3.common.torch_layers.NatureCNN`) + graph_kwargs: to be passed to GraphFeaturesExtractor extractor + + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + cnn_kwargs: Optional[dict[str, Any]] = None, + graph_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + if cnn_kwargs is None: + cnn_kwargs = {} + if graph_kwargs is None: + graph_kwargs = {} + normalized_image = cnn_kwargs.get("normalized_image", False) + + extractors: dict[str, nn.Module] = {} + total_concat_size = 0 + for key, subspace in observation_space.spaces.items(): + if isinstance(subspace, gym.spaces.Graph): + extractors[key] = GraphFeaturesExtractor(subspace, **graph_kwargs) + total_concat_size += extractors[key].features_dim + elif is_image_space(subspace, normalized_image=normalized_image): + extractors[key] = NatureCNN(subspace, **cnn_kwargs) + total_concat_size += extractors[key].features_dim + else: + # The observation key is a vector, flatten it if needed + extractors[key] = nn.Flatten() + total_concat_size += get_flattened_obs_dim(subspace) + + # call __init__ before assigning attributes (but after computing total_concat_size) + super().__init__(observation_space, features_dim=total_concat_size) + self.extractors = nn.ModuleDict(extractors) + + def forward( + self, observations: dict[str, Union[th.Tensor, thg.data.Data]] + ) -> th.Tensor: + encoded_tensor_list = [] + + for key, extractor in self.extractors.items(): + encoded_tensor_list.append(extractor(observations[key])) + return th.cat(encoded_tensor_list, dim=1) diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/utils.py b/skdecide/hub/solver/stable_baselines/gnn/common/utils.py new file mode 100644 index 0000000000..e0abcef987 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/utils.py @@ -0,0 +1,94 @@ +from typing import Union + +import gymnasium as gym +import numpy as np +import stable_baselines3.common.utils +import torch as th +import torch_geometric as thg + +SubObsType = Union[np.ndarray, gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]] +ObsType = Union[SubObsType, dict[str, SubObsType]] +TorchSubObsType = Union[th.Tensor, thg.data.Data] +TorchObsType = Union[TorchSubObsType, dict[str, TorchSubObsType]] + + +def copy_graph_instance(g: gym.spaces.GraphInstance) -> gym.spaces.GraphInstance: + return gym.spaces.GraphInstance( + nodes=np.copy(g.nodes), edges=np.copy(g.edges), edge_links=np.copy(g.edge_links) + ) + + +def copy_np_array_or_list_of_graph_instances( + obs: Union[np.ndarray, list[gym.spaces.GraphInstance]] +) -> Union[np.ndarray, list[gym.spaces.GraphInstance]]: + if isinstance(obs[0], gym.spaces.GraphInstance): + return [copy_graph_instance(g) for g in obs] + else: + return np.copy(obs) + + +def graph_obs_to_thg_data( + obs: gym.spaces.GraphInstance, device: th.device +) -> thg.data.Data: + # Node features + flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1)) + x = th.tensor(flatten_node_features).float() + # Edge features + if obs.edges is None: + edge_attr = None + else: + flatten_edge_features = obs.edges.reshape((len(obs.edges), -1)) + edge_attr = th.tensor(flatten_edge_features).float() + edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1) + return thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr).to(device) + + +def obs_as_tensor( + obs: ObsType, + device: th.device, +) -> TorchObsType: + """ + Moves the observation to the given device. + + Args: + obs: + device: PyTorch device + + Returns: + PyTorch tensor of the observation on a desired device. + + """ + if isinstance(obs, gym.spaces.GraphInstance): + return graph_obs_to_thg_data(obs, device=device) + elif isinstance(obs, list) and isinstance(obs[0], gym.spaces.GraphInstance): + if len(obs) > 1: + raise NotImplementedError( + "Not implemented for real vectorized environment " + "(ie. with more than 1 wrapped environment)" + ) + return graph_obs_to_thg_data(obs[0], device=device) + elif isinstance(obs, np.ndarray): + return th.as_tensor(obs, device=device) + elif isinstance(obs, dict): + return {key: obs_as_tensor(_obs, device=device) for (key, _obs) in obs.items()} + else: + raise Exception(f"Unrecognized type of observation {type(obs)}") + + +def is_vectorized_observation( + observation: SubObsType, observation_space: gym.spaces.Space +) -> bool: + """ + For every observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if isinstance(observation_space, gym.spaces.Graph): + return isinstance(observation_space, list) + else: + return stable_baselines3.common.utils.is_vectorized_observation( + observation=observation, observation_space=observation_space + ) diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/vec_env/__init__.py b/skdecide/hub/solver/stable_baselines/gnn/common/vec_env/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/vec_env/dummy_vec_env.py b/skdecide/hub/solver/stable_baselines/gnn/common/vec_env/dummy_vec_env.py new file mode 100644 index 0000000000..324a075b44 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/common/vec_env/dummy_vec_env.py @@ -0,0 +1,78 @@ +from collections import OrderedDict +from typing import Callable, List, Union + +import gymnasium as gym +import numpy as np +from stable_baselines3.common.env_util import is_wrapped +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.type_aliases import GymEnv +from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs +from stable_baselines3.common.vec_env.util import dict_to_obs + +from ..utils import copy_np_array_or_list_of_graph_instances + +EnvSubObs = Union[np.ndarray, list[gym.spaces.GraphInstance]] +VecEnvObs = Union[EnvSubObs, dict[str, EnvSubObs], tuple[EnvSubObs, ...]] + + +class GraphDummyVecEnv(DummyVecEnv): + def __init__(self, env_fns: List[Callable[[], gym.Env]]): + super().__init__(env_fns) + # replace buffers for graph spaces by lists + obs_space = self.envs[0].observation_space + if isinstance(obs_space, gym.spaces.Graph): + self.buf_obs[None] = [None for _ in range(self.num_envs)] + elif isinstance(obs_space, gym.spaces.Dict): + for k, space in obs_space.spaces.items(): + if isinstance(space, gym.spaces.Graph): + self.buf_obs[k] = [None for _ in range(self.num_envs)] + + def _obs_from_buf(self) -> VecEnvObs: + return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + + +def copy_obs_dict(obs: dict[str, EnvSubObs]) -> dict[str, EnvSubObs]: + """ + Deep-copy a dict of numpy arrays. + + :param obs: a dict of numpy arrays. + :return: a dict of copied numpy arrays. + """ + assert isinstance( + obs, OrderedDict + ), f"unexpected type for observations '{type(obs)}'" + return OrderedDict( + [(k, copy_np_array_or_list_of_graph_instances(v)) for k, v in obs.items()] + ) + + +def wrap_graph_env( + env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True +) -> VecEnv: + """Wrap environment with the appropriate wrappers if needed. + + :param env: + :param verbose: Verbosity level: 0 for no output, 1 for indicating wrappers used + :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible. + :return: The wrapped environment. + """ + if not isinstance(env, VecEnv): + if not is_wrapped(env, Monitor) and monitor_wrapper: + if verbose >= 1: + print("Wrapping the env with a `Monitor` wrapper") + env = Monitor(env) + if verbose >= 1: + print("Wrapping the env in a DummyVecEnv.") + # patch: add dummy shape and dtype to graph obs space to avoid issues + observation_space = env.observation_space + if isinstance(observation_space, gym.spaces.Graph): + observation_space._shape = (0,) + observation_space.dtype = np.float_ + elif isinstance(observation_space, gym.spaces.Dict): + for subspace in observation_space.spaces.values(): + if isinstance(subspace, gym.spaces.Graph): + subspace._shape = (0,) + subspace.dtype = np.float_ + env = GraphDummyVecEnv([lambda: env]) # type: ignore[list-item, return-value] + return env diff --git a/skdecide/hub/solver/stable_baselines/gnn/ppo/__init__.py b/skdecide/hub/solver/stable_baselines/gnn/ppo/__init__.py new file mode 100644 index 0000000000..597dab7c40 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/ppo/__init__.py @@ -0,0 +1,5 @@ +from ..common.policies import GNNActorCriticPolicy, MultiInputGNNActorCriticPolicy +from .ppo import GraphPPO + +GraphInputPolicy = GNNActorCriticPolicy +MultiInputPolicy = MultiInputGNNActorCriticPolicy diff --git a/skdecide/hub/solver/stable_baselines/gnn/ppo/ppo.py b/skdecide/hub/solver/stable_baselines/gnn/ppo/ppo.py new file mode 100644 index 0000000000..9afac8bde1 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/ppo/ppo.py @@ -0,0 +1,14 @@ +from typing import ClassVar + +from stable_baselines3 import PPO +from stable_baselines3.common.policies import BasePolicy + +from ..common.on_policy_algorithm import GraphOnPolicyAlgorithm +from ..common.policies import GNNActorCriticPolicy, MultiInputGNNActorCriticPolicy + + +class GraphPPO(GraphOnPolicyAlgorithm, PPO): + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { + "GraphInputPolicy": GNNActorCriticPolicy, + "MultiInputPolicy": MultiInputGNNActorCriticPolicy, + } diff --git a/skdecide/hub/solver/stable_baselines/stable_baselines.py b/skdecide/hub/solver/stable_baselines/stable_baselines.py index 7f4cfa7daf..78abcf7795 100644 --- a/skdecide/hub/solver/stable_baselines/stable_baselines.py +++ b/skdecide/hub/solver/stable_baselines/stable_baselines.py @@ -7,6 +7,7 @@ from collections.abc import Callable from typing import Any, Optional, Union +import gymnasium as gym from discrete_optimization.generic_tools.hyperparameters.hyperparameter import ( CategoricalHyperparameter, FloatHyperparameter, @@ -17,7 +18,6 @@ from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import MaybeCallback -from stable_baselines3.common.vec_env import DummyVecEnv from skdecide import Domain, Solver from skdecide.builders.domain import ( @@ -132,6 +132,9 @@ def __init__( ent_coef_log = kwargs.pop("ent_coef_log") kwargs["ent_coef"] = 10**ent_coef_log + def _as_gymnasium_env(self, domain: Domain) -> gym.Env: + return as_gymnasium_env(domain) + @classmethod def _check_domain_additional(cls, domain: Domain) -> bool: return isinstance(domain.get_action_space(), GymSpace) and isinstance( @@ -146,9 +149,7 @@ def _solve(self) -> None: self, "_algo" ): # reuse algo if possible (enables further learning) domain = self._domain_factory() - env = DummyVecEnv( - [lambda: AsGymnasiumEnv(domain)] - ) # the algorithms require a vectorized environment to run + env = self._as_gymnasium_env(domain) self._algo = self._algo_class( self._baselines_policy, env, **self._algo_kwargs ) @@ -182,8 +183,7 @@ def _save(self, path: str) -> None: def _load(self, path: str): domain = self._domain_factory() - env = DummyVecEnv([lambda: AsGymnasiumEnv(domain)]) - self._algo = self._algo_class.load(path, env=env) + self._algo = self._algo_class.load(path, env=self._as_gymnasium_env(domain)) self._init_algo(domain) def _init_algo(self, domain: D): @@ -209,3 +209,12 @@ def __init__( def _on_step(self) -> bool: return not self.callback(self.solver) + + +def as_gymnasium_env(domain: Domain) -> gym.Env: + """Wraps the domain into a gymnasium env. + + To be fed to sb3 algorithms. + + """ + return AsGymnasiumEnv(domain=domain) diff --git a/tests/solvers/python/test_gnn_sb3.py b/tests/solvers/python/test_gnn_sb3.py new file mode 100644 index 0000000000..b18e8f8900 --- /dev/null +++ b/tests/solvers/python/test_gnn_sb3.py @@ -0,0 +1,459 @@ +from typing import Any, Optional + +import numpy as np +import numpy.typing as npt +import torch as th +import torch_geometric as thg +from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv +from gymnasium.spaces import Box, Discrete, Graph, GraphInstance +from pytest_cases import fixture, fixture_union, param_fixture +from torch_geometric.nn import global_add_pool + +from skdecide.builders.domain import Renderable, UnrestrictedActions +from skdecide.core import Space, TransitionOutcome, Value +from skdecide.domains import DeterministicPlanningDomain +from skdecide.hub.domain.gym import GymDomain +from skdecide.hub.domain.maze import Maze +from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State +from skdecide.hub.solver.stable_baselines import StableBaseline +from skdecide.hub.solver.stable_baselines.gnn import GraphPPO +from skdecide.hub.solver.stable_baselines.gnn.common.torch_layers import ( + GraphFeaturesExtractor, +) +from skdecide.hub.space.gym import DictSpace, GymSpace, ListSpace +from skdecide.utils import rollout + +# JSP graph env + + +class D(GymDomain): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = int # Type of events + T_value = float # Type of transition values (rewards or costs) + T_info = None # Type of additional information in environment outcome + + +class GraphJspDomain(D): + _gym_env: DisjunctiveGraphJspEnv + + def __init__(self, gym_env): + GymDomain.__init__(self, gym_env=gym_env) + if self._gym_env.normalize_observation_space: + self.n_nodes_features = gym_env.n_machines + 1 + else: + self.n_nodes_features = 2 + + def _state_step( + self, action: D.T_event + ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: + outcome = super()._state_step(action=action) + outcome.state = self._np_state2graph_state(outcome.state) + return outcome + + def _get_applicable_actions_from( + self, memory: D.T_memory[D.T_state] + ) -> D.T_agent[Space[D.T_event]]: + return ListSpace(np.nonzero(self._gym_env.valid_action_mask())[0]) + + def _is_applicable_action_from( + self, action: D.T_agent[D.T_event], memory: D.T_memory[D.T_state] + ) -> bool: + return self._gym_env.valid_action_mask()[action] + + def _state_reset(self) -> D.T_state: + return self._np_state2graph_state(super()._state_reset()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + if self._gym_env.normalize_observation_space: + original_graph_space = Graph( + node_space=Box( + low=0.0, + high=1.0, + shape=(self.n_nodes_features,), + dtype=np.float_, + ), + edge_space=Box(low=0, high=1.0, dtype=np.float_), + ) + + else: + original_graph_space = Graph( + node_space=Box( + low=np.array([0, 0]), + high=np.array( + [ + self._gym_env.n_machines, + self._gym_env.longest_processing_time, + ] + ), + dtype=np.int_, + ), + edge_space=Box( + low=0, high=self._gym_env.longest_processing_time, dtype=np.int_ + ), + ) + return GymSpace(original_graph_space) + + def _np_state2graph_state(self, np_state: np.array) -> GraphInstance: + if not self._gym_env.normalize_observation_space: + np_state = np_state.astype(np.int_) + + nodes = np_state[:, -self.n_nodes_features :] + adj = np_state[:, : -self.n_nodes_features] + edge_starts_ends = adj.nonzero() + edge_links = np.transpose(edge_starts_ends) + edges = adj[edge_starts_ends][:, None] + + return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links) + + def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: + return self._gym_env.render(**kwargs) + + +class D(GraphJspDomain): + T_state = dict[str, Any] + + +class MultiInputGraphJspDomain(D): + def _get_observation_space_(self) -> Space[D.T_observation]: + return DictSpace( + spaces=dict( + graph=super()._get_observation_space_(), + static=Box(low=0.0, high=1.0), + ) + ) + + def _state_step( + self, action: D.T_event + ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: + transition = super()._state_step(action) + transition.state = dict(graph=transition.state, static=np.array([0.5])) + return transition + + def _state_reset(self) -> D.T_state: + return dict(graph=super()._state_reset(), static=np.array([0.5])) + + +jsp = np.array( + [ + [ + [0, 1, 2], # machines for job 0 + [0, 2, 1], # machines for job 1 + [0, 1, 2], # machines for job 2 + ], + [ + [3, 2, 2], # task durations of job 0 + [2, 1, 4], # task durations of job 1 + [0, 4, 3], # task durations of job 2 + ], + ] +) + + +class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class GraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.discrete_features = discrete_features + self.maze_domain = Maze(maze_str=maze_str) + np_wall = np.array(self.maze_domain._maze) + np_y = np.array( + [ + [(i) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + np_x = np.array( + [ + [(j) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + walls = np_wall.ravel() + coords = [i for i in zip(np_y.ravel(), np_x.ravel())] + np_node_id = np.reshape(range(len(walls)), np_wall.shape) + edge_links = [] + edges = [] + for i in range(self.maze_domain._num_rows): + for j in range(self.maze_domain._num_cols): + current_coord = (i, j) + if i > 0: + next_coord = (i - 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if i < self.maze_domain._num_rows - 1: + next_coord = (i + 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j > 0: + next_coord = (i, j - 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j < self.maze_domain._num_cols - 1: + next_coord = (i, j + 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + self.edges = np.array(edges) + self.edge_links = np.array(edge_links) + self.walls = walls + self.node_ids = np_node_id + self.coords = coords + + def _mazestate2graph(self, state: State) -> GraphInstance: + x, y = state + agent_presence = np.zeros(self.walls.shape, dtype=self.walls.dtype) + agent_presence[self.node_ids[y, x]] = 1 + nodes = np.stack([self.walls, agent_presence], axis=-1) + if self.discrete_features: + return GraphInstance( + nodes=nodes, edges=self.edges, edge_links=self.edge_links + ) + else: + return GraphInstance( + nodes=nodes, edges=self.edges[:, None], edge_links=self.edge_links + ) + + def _graph2mazestate(self, graph: GraphInstance) -> State: + y, x = self.coords[graph.nodes[:, 1].nonzero()[0][0]] + return State(x=x, y=y) + + def _is_terminal(self, state: D.T_state) -> D.T_predicate: + return self.maze_domain._is_terminal(self._graph2mazestate(state)) + + def _get_next_state(self, memory: D.T_state, action: D.T_event) -> D.T_state: + maze_memory = self._graph2mazestate(memory) + maze_next_state = self.maze_domain._get_next_state( + memory=maze_memory, action=action + ) + return self._mazestate2graph(maze_next_state) + + def _get_transition_value( + self, + memory: D.T_state, + action: D.T_event, + next_state: Optional[D.T_state] = None, + ) -> Value[D.T_value]: + maze_memory = self._graph2mazestate(memory) + if next_state is None: + maze_next_state = None + else: + maze_next_state = self._graph2mazestate(next_state) + return self.maze_domain._get_transition_value( + memory=maze_memory, action=action, next_state=maze_next_state + ) + + def _get_action_space_(self) -> Space[D.T_event]: + return self.maze_domain._get_action_space_() + + def _get_goals_(self) -> Space[D.T_observation]: + return ListSpace([self._mazestate2graph(self.maze_domain._goal)]) + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.maze_domain._is_goal(self._graph2mazestate(observation)) + + def _get_initial_state_(self) -> D.T_state: + return self._mazestate2graph(self.maze_domain._get_initial_state_()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + if self.discrete_features: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Discrete(2), + ) + ) + else: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Box(low=0, high=1, shape=(1,), dtype=self.edges.dtype), + ) + ) + + def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: + maze_memory = self._graph2mazestate(memory) + self.maze_domain._render_from(memory=maze_memory, **kwargs) + + +discrete_features = param_fixture("discrete_features", [False, True]) + + +@fixture +def maze_domain_factory(discrete_features): + return lambda: GraphMaze(discrete_features=discrete_features) + + +@fixture +def jsp_domain_factory(): + return lambda: GraphJspDomain( + gym_env=DisjunctiveGraphJspEnv( + jps_instance=jsp, + perform_left_shift_if_possible=True, + normalize_observation_space=False, + flat_observation_space=False, + action_mode="task", + ) + ) + + +@fixture +def jsp_dict_domain_factory(): + return lambda: MultiInputGraphJspDomain( + gym_env=DisjunctiveGraphJspEnv( + jps_instance=jsp, + perform_left_shift_if_possible=True, + normalize_observation_space=False, + flat_observation_space=False, + action_mode="task", + ) + ) + + +domain_factory = fixture_union( + "domain_factory", [maze_domain_factory, jsp_domain_factory] +) + + +def test_observation_space(domain_factory): + domain = domain_factory() + assert domain.reset() in domain.get_observation_space() + + +def test_ppo(domain_factory): + with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_ppo_user_gnn(domain_factory): + domain = domain_factory() + node_features_dim = int( + np.prod(domain.get_observation_space().unwrapped().node_space.shape) + ) + with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + policy_kwargs=dict( + features_extractor_class=GraphFeaturesExtractor, + features_extractor_kwargs=dict( + gnn_class=thg.nn.models.GAT, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ), + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +class MyReductionLayer(th.nn.Module): + def __init__(self, gnn_out_dim: int, features_dim: int): + super().__init__() + self.gnn_out_dim = gnn_out_dim + self.features_dim = features_dim + self.linear_layer = th.nn.Linear(gnn_out_dim, features_dim) + + def forward(self, embedded_observations: thg.data.Data) -> th.Tensor: + x, edge_index, batch = ( + embedded_observations.x, + embedded_observations.edge_index, + embedded_observations.batch, + ) + h = global_add_pool(x, batch) + h = self.linear_layer(h).relu() + return h + + +def test_ppo_user_reduction_layer(domain_factory): + with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphPPO, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + policy_kwargs=dict( + features_extractor_class=GraphFeaturesExtractor, + features_extractor_kwargs=dict( + gnn_out_dim=128, + features_dim=64, + reduction_layer_class=MyReductionLayer, + reduction_layer_kwargs=dict( + gnn_out_dim=128, + features_dim=64, + ), + ), + ), + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_dict_ppo(jsp_dict_domain_factory): + domain_factory = jsp_dict_domain_factory + with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphPPO, + baselines_policy="MultiInputPolicy", + learn_config={"total_timesteps": 100}, + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + )