Implement a GNN PPO for ray-rllib #460

Merged Jan 17, 2025 (1 commit)

@nhuet nhuet commented Jan 13, 2025

We follow the same guidelines as for the sb3 wrapper:

  • The GNN is based on pytorch-geometric
  • Features are extracted by a GNN followed by a reduction layer that
    outputs a fixed number of features
  • The observation is a Graph, or a dict whose values contain at least one Graph
  • Action masks are taken into account if available
  • The user must use GraphPPO instead of PPO as the algorithm: GraphPPO overrides
    PPO to change the way obs is converted to pytorch format
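To illustrate why a reduction layer is needed after the GNN, here is a minimal pure-Python sketch. It is not the code from this PR (which relies on pytorch-geometric); the function names, the mean aggregation and the mean pooling are all assumptions chosen for illustration. It shows that graphs with different numbers of nodes end up as feature vectors of the same length.

```python
from typing import List, Tuple

# Hypothetical graph encoding for this sketch: (node features, directed edges).
Graph = Tuple[List[List[float]], List[Tuple[int, int]]]


def message_passing(nodes: List[List[float]], edges: List[Tuple[int, int]]) -> List[List[float]]:
    """One round of mean aggregation: each node averages itself with its in-neighbors."""
    dim = len(nodes[0])
    sums = [list(x) for x in nodes]   # start from each node's own features
    counts = [1] * len(nodes)         # each node counts itself once
    for src, dst in edges:
        for k in range(dim):
            sums[dst][k] += nodes[src][k]
        counts[dst] += 1
    return [[s / c for s in row] for row, c in zip(sums, counts)]


def reduce_to_fixed_size(nodes: List[List[float]]) -> List[float]:
    """Mean-pool over nodes: output length equals the feature dim, whatever the node count."""
    n = len(nodes)
    return [sum(col) / n for col in zip(*nodes)]


def extract_features(graph: Graph) -> List[float]:
    nodes, edges = graph
    return reduce_to_fixed_size(message_passing(nodes, edges))


small = ([[1.0, 0.0], [0.0, 1.0]], [(0, 1), (1, 0)])
large = ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [(0, 1), (1, 2), (2, 0)])
f_small = extract_features(small)
f_large = extract_features(large)
print(len(f_small), len(f_large))  # both equal the node feature dimension
```

Because the output size no longer depends on the graph, a standard fully connected policy network can consume it.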

Worth noting:

  • We use the old API stack, as the RLlib wrapper currently uses it
  • For graph observations, the model is a GNN features extractor followed by a FullyConnectedNetwork
  • For dicts of graphs (and other components), the model
    • preprocesses the obs with the GNN features extractor for graph components
    • applies a ComplexInputNetwork to the preprocessed obs
  • Action masking is automatically activated according to the domain class
    (not UnrestrictedActions) and the algo class, as was already coded in the
    RayRLlib wrapper. The algo to use is still GraphPPO, as masking is
    managed by a custom model at the RayRLlib wrapper level.
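As a reminder of what action masking does at the policy output, here is a minimal sketch in plain Python. It is an assumption about the general technique, not the custom model used by the RayRLlib wrapper; `masked_action_probs` is a hypothetical helper. Logits of forbidden actions are pushed to minus infinity before the softmax, so those actions get probability zero.

```python
import math
from typing import List, Sequence


def masked_action_probs(logits: Sequence[float], mask: Sequence[int]) -> List[float]:
    """Softmax over logits where actions with mask == 0 get probability 0."""
    if not any(mask):
        raise ValueError("at least one action must be allowed")
    # Forbidden actions get a logit of -inf, so exp(...) evaluates to 0.0 for them.
    masked = [l if m else -math.inf for l, m in zip(logits, mask)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_action_probs([2.0, 1.0, 0.5], [1, 0, 1])
print(probs)  # the second action has probability 0.0
```

In a tensor-based model, the same effect is commonly obtained by adding a large negative value to the logits of masked actions.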

@nhuet nhuet changed the title Implement a GNN PPO based on ray-rllib + torch_geometric Implement a GNN PPO for ray-rllib Jan 16, 2025
@fteicht fteicht merged commit 324ebf4 into airbus:master Jan 17, 2025
33 checks passed
@nhuet nhuet deleted the ray-gnn branch January 20, 2025 09:48