Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GraphA2C algo, ie sb3 A2C with GNN feature extraction #451

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions skdecide/hub/solver/stable_baselines/gnn/a2c/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ..common.policies import GNNActorCriticPolicy, MultiInputGNNActorCriticPolicy
from .a2c import GraphA2C

GraphInputPolicy = GNNActorCriticPolicy
MultiInputPolicy = MultiInputGNNActorCriticPolicy
14 changes: 14 additions & 0 deletions skdecide/hub/solver/stable_baselines/gnn/a2c/a2c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import ClassVar

from stable_baselines3 import A2C
from stable_baselines3.common.policies import BasePolicy

from ..common.on_policy_algorithm import GraphOnPolicyAlgorithm
from ..common.policies import GNNActorCriticPolicy, MultiInputGNNActorCriticPolicy


class GraphA2C(GraphOnPolicyAlgorithm, A2C):
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"GraphInputPolicy": GNNActorCriticPolicy,
"MultiInputPolicy": MultiInputGNNActorCriticPolicy,
}
39 changes: 39 additions & 0 deletions tests/solvers/python/test_gnn_sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn import GraphPPO
from skdecide.hub.solver.stable_baselines.gnn.a2c import GraphA2C
from skdecide.hub.solver.stable_baselines.gnn.common.torch_layers import (
GraphFeaturesExtractor,
)
Expand Down Expand Up @@ -516,3 +517,41 @@ def test_dict_maskable_ppo(jsp_dict_domain_factory):
render=False,
use_applicable_actions=True,
)


def test_dict_a2c(jsp_dict_domain_factory):
domain_factory = jsp_dict_domain_factory
with StableBaseline(
domain_factory=domain_factory,
algo_class=GraphA2C,
baselines_policy="MultiInputPolicy",
learn_config={"total_timesteps": 100},
) as solver:

solver.solve()
rollout(
domain=domain_factory(),
solver=solver,
max_steps=100,
num_episodes=1,
render=False,
)


def test_a2c(jsp_domain_factory):
domain_factory = jsp_domain_factory
with StableBaseline(
domain_factory=domain_factory,
algo_class=GraphA2C,
baselines_policy="GraphInputPolicy",
learn_config={"total_timesteps": 100},
) as solver:

solver.solve()
rollout(
domain=domain_factory(),
solver=solver,
max_steps=100,
num_episodes=1,
render=False,
)
Loading