diff --git a/skdecide/hub/solver/stable_baselines/gnn/a2c/__init__.py b/skdecide/hub/solver/stable_baselines/gnn/a2c/__init__.py new file mode 100644 index 0000000000..6b6d84f6c3 --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/a2c/__init__.py @@ -0,0 +1,5 @@ +from ..common.policies import GNNActorCriticPolicy, MultiInputGNNActorCriticPolicy +from .a2c import GraphA2C + +GraphInputPolicy = GNNActorCriticPolicy +MultiInputPolicy = MultiInputGNNActorCriticPolicy diff --git a/skdecide/hub/solver/stable_baselines/gnn/a2c/a2c.py b/skdecide/hub/solver/stable_baselines/gnn/a2c/a2c.py new file mode 100644 index 0000000000..ecead362fd --- /dev/null +++ b/skdecide/hub/solver/stable_baselines/gnn/a2c/a2c.py @@ -0,0 +1,14 @@ +from typing import ClassVar + +from stable_baselines3 import A2C +from stable_baselines3.common.policies import BasePolicy + +from ..common.on_policy_algorithm import GraphOnPolicyAlgorithm +from ..common.policies import GNNActorCriticPolicy, MultiInputGNNActorCriticPolicy + + +class GraphA2C(GraphOnPolicyAlgorithm, A2C): + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { + "GraphInputPolicy": GNNActorCriticPolicy, + "MultiInputPolicy": MultiInputGNNActorCriticPolicy, + } diff --git a/tests/solvers/python/test_gnn_sb3.py b/tests/solvers/python/test_gnn_sb3.py index b7d81b7464..72029fb16c 100644 --- a/tests/solvers/python/test_gnn_sb3.py +++ b/tests/solvers/python/test_gnn_sb3.py @@ -16,6 +16,7 @@ from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State from skdecide.hub.solver.stable_baselines import StableBaseline from skdecide.hub.solver.stable_baselines.gnn import GraphPPO +from skdecide.hub.solver.stable_baselines.gnn.a2c import GraphA2C from skdecide.hub.solver.stable_baselines.gnn.common.torch_layers import ( GraphFeaturesExtractor, ) @@ -516,3 +517,41 @@ def test_dict_maskable_ppo(jsp_dict_domain_factory): render=False, use_applicable_actions=True, ) + + +def test_dict_a2c(jsp_dict_domain_factory): + domain_factory = jsp_dict_domain_factory + with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphA2C, + baselines_policy="MultiInputPolicy", + learn_config={"total_timesteps": 100}, + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_a2c(jsp_domain_factory): + domain_factory = jsp_domain_factory + with StableBaseline( + domain_factory=domain_factory, + algo_class=GraphA2C, + baselines_policy="GraphInputPolicy", + learn_config={"total_timesteps": 100}, + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + )