Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a sb3 algo + policy for domains with graph observations #441

Merged
merged 1 commit into from
Dec 16, 2024

Conversation

nhuet
Copy link
Contributor

@nhuet nhuet commented Nov 19, 2024

  • we reuse our stable_baselines3 wrapper
  • the policy is extracting features from the graph with a GNN
  • the GNN is using pytorch-geometric
  • We subclass
    • ActorCriticPolicy:
      • feature extractor = gnn
      • custom conversion of observation to torch to convert into torch_geometric.data.Data
    • PPO to handle properly
      • observation conversion
      • rollout buffer
  • Current limitations:
    • we extract a fixed number of features (independent of edge/node numbers) for now as we end with a feature reduction layer connected to a classic mlp (not knowning anything about the current graph structure)
  • User input: the user can define (and default choices are made else)
    • the gnn (default to a 2 layers GCN), taking as inputs w.r.t torch_geometric conventions:
      • x: nodes features
      • edge_index: edge indices or sparse transposed adjency matrix
      • edge_attr (optional): edges features
      • edge_weight (optional): edge weights (taken from first dimension of edge_attr)
    • the feature reduction layer from the gnn output to the fixed number of features (default to global_max_pool + linear layer + relu)

We also introduce a multiinput policy to take into account (for instance) static graph features. The observation space in that case is a DictSpace whose subspaces can contain some Graph spaces.

@nhuet nhuet force-pushed the gnn-sb3 branch 9 times, most recently from ba35de0 to 204c1ed Compare November 26, 2024 16:20
@nhuet nhuet force-pushed the gnn-sb3 branch 3 times, most recently from 993b819 to ecfd289 Compare December 5, 2024 09:18
- we reuse our stable_baselines3 wrapper
- the policy is extracting features from the graph with a GNN
- the GNN is using pytorch-geometric
- We subclass
  - ActorCriticPolicy:
    - feature extractor = gnn
    - custom conversion of observation to torch to convert into
      torch_geometric.data.Data
  - PPO to handle properly
    - observation conversion
    - rollout buffer
- Current limitations:
  - we extract a fixed number of features (independent of edge/node
    numbers) for now as we end with a feature reduction layer connected
    to a classic mlp (not knowning anything about the current graph structure)
- User input: the user can define (and default choices are made else)
  - the gnn (default to a 2 layers GCN), taking as inputs w.r.t torch_geometric conventions:
    - x: nodes features
    - edge_index: edge indices or sparse transposed adjency matrix
    - edge_attr (optional): edges features
    - edge_weight (optional): edge weights (taken from first dimension
      of edge_attr)
  - the feature reduction layer from the gnn output to the fixed number of features
    (default to global_max_pool + linear layer + relu)

We also introduce a multiinput policy to take into account static graph
features. The observation space is a DictSpace whose subspaces can
contain some Graph spaces.
Copy link
Collaborator

@neo-alex neo-alex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great addition to the library, thank you! LGTM

@neo-alex neo-alex merged commit ab49ecb into airbus:master Dec 16, 2024
27 of 33 checks passed
@nhuet nhuet deleted the gnn-sb3 branch January 20, 2025 09:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants