Implement a GNN PPO for ray-rllib #460

Merged
merged 1 commit on Jan 17, 2025
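For context, a minimal usage sketch of the new algorithm with scikit-decide's RayRLlib solver (not part of the diff): `MyGraphDomain` is a hypothetical domain whose observations live in a gymnasium `Graph` space, and the constructor arguments assume the solver's existing `algo_class`/`train_iterations` API.

```python
# Minimal usage sketch (not part of the diff). `MyGraphDomain` is a
# hypothetical domain with Graph observations; the RayRLlib arguments follow
# the solver's existing public API.
from skdecide.hub.solver.ray_rllib import RayRLlib
from skdecide.hub.solver.ray_rllib.gnn.algorithms import GraphPPO

solver = RayRLlib(
    domain_factory=lambda: MyGraphDomain(),  # hypothetical graph-observation domain
    algo_class=GraphPPO,  # GNN-aware PPO introduced by this PR
    train_iterations=5,
)
solver.solve()
```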
45 changes: 36 additions & 9 deletions skdecide/hub/solver/ray_rllib/custom_models.py
@@ -1,4 +1,5 @@
from gymnasium.spaces import flatten_space
from ray.rllib import SampleBatch
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as TFFullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import (
@@ -9,6 +10,15 @@
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, unbatch
from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN

from skdecide.hub.solver.ray_rllib.gnn.models.torch.complex_input_net import (
GraphComplexInputNetwork,
)
from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
is_graph_dict_multiinput_space,
is_graph_dict_space,
)

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

@@ -98,8 +108,20 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name, **k
self.action_ids_shifted = torch.arange(1, num_outputs + 1, dtype=torch.int64)
self.true_obs_space = model_config["custom_model_config"]["true_obs_space"]

self.pred_action_embed_model = TorchFullyConnectedNetwork(
flatten_space(self.true_obs_space),
if is_graph_dict_space(self.true_obs_space):
pred_action_embed_model_cls = GnnBasedModel
self.obs_with_graph = True
embed_model_obs_space = self.true_obs_space
elif is_graph_dict_multiinput_space(self.true_obs_space):
pred_action_embed_model_cls = GraphComplexInputNetwork
self.obs_with_graph = True
embed_model_obs_space = self.true_obs_space
else:
pred_action_embed_model_cls = TorchFullyConnectedNetwork
self.obs_with_graph = False
embed_model_obs_space = flatten_space(self.true_obs_space)
self.pred_action_embed_model = pred_action_embed_model_cls(
embed_model_obs_space,
action_space,
model_config["custom_model_config"]["action_embed_size"],
model_config,
@@ -115,16 +137,21 @@ def forward(self, input_dict, state, seq_lens):
# Extract the available actions mask tensor from the observation.
valid_avail_actions_mask = input_dict["obs"]["valid_avail_actions_mask"]

# Unbatch true observations before flattening them
unbatched_true_obs = unbatch(input_dict["obs"]["true_obs"])
if self.obs_with_graph:
# use the obs directly (already converted to the proper format by the custom `convert_to_torch_tensor`)
embed_model_obs = input_dict["obs"]["true_obs"]
else:
# Unbatch true observations before flattening them
embed_model_obs = torch.stack(
[
flatten_to_single_ndarray(o)
for o in unbatch(input_dict["obs"]["true_obs"])
]
)

# Compute the predicted action embedding
pred_action_embed, _ = self.pred_action_embed_model(
{
"obs": torch.stack(
[flatten_to_single_ndarray(o) for o in unbatched_true_obs]
)
}
SampleBatch({SampleBatch.OBS: embed_model_obs})
)

# Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
Empty file.
1 change: 1 addition & 0 deletions skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py
@@ -0,0 +1 @@
from .ppo.ppo import GraphPPO
Empty file.
21 changes: 21 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py
@@ -0,0 +1,21 @@
from typing import Optional

from ray.rllib import Policy
from ray.rllib.algorithms import PPO, AlgorithmConfig

from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import (
PPOTorchGraphPolicy,
)


class GraphPPO(PPO):
@classmethod
def get_default_policy_class(
cls, config: AlgorithmConfig
) -> Optional[type[Policy]]:
if config["framework"] == "torch":
return PPOTorchGraphPolicy
elif config["framework"] == "tf":
raise NotImplementedError("GraphPPO is only implemented for the torch framework")
else:
raise NotImplementedError("GraphPPO is only implemented for the torch framework")
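An illustrative check (not in the diff) of how RLlib resolves the policy class through this hook: a torch-framework config yields the graph-aware torch policy, any other framework raises.

```python
# Illustrative check of the policy-class resolution defined above.
from ray.rllib.algorithms.ppo import PPOConfig

from skdecide.hub.solver.ray_rllib.gnn.algorithms import GraphPPO
from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import (
    PPOTorchGraphPolicy,
)

config = PPOConfig().framework("torch")
assert GraphPPO.get_default_policy_class(config) is PPOTorchGraphPolicy
```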
7 changes: 7 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py
@@ -0,0 +1,7 @@
from ray.rllib.algorithms.ppo import PPOTorchPolicy

from skdecide.hub.solver.ray_rllib.gnn.policy.torch_graph_policy import TorchGraphPolicy


class PPOTorchGraphPolicy(TorchGraphPolicy, PPOTorchPolicy):
...
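The class body is intentionally empty: the behaviour comes from the method resolution order, where TorchGraphPolicy sits ahead of RLlib's TorchPolicyV2, so its graph-aware `_lazy_tensor_dict` override (added below in torch_graph_policy.py) is the one that gets called. A quick illustrative check, not part of the diff:

```python
from ray.rllib.algorithms.ppo import PPOTorchPolicy
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import (
    PPOTorchGraphPolicy,
)
from skdecide.hub.solver.ray_rllib.gnn.policy.torch_graph_policy import TorchGraphPolicy

mro = PPOTorchGraphPolicy.__mro__
# The graph-aware policy precedes both PPOTorchPolicy and TorchPolicyV2,
# so its overrides (e.g. _lazy_tensor_dict) take precedence.
assert mro.index(TorchGraphPolicy) < mro.index(PPOTorchPolicy) < mro.index(TorchPolicyV2)
```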
Empty file.
Empty file.
66 changes: 66 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py
@@ -0,0 +1,66 @@
import gymnasium as gym
from ray.rllib import SampleBatch
from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import TensorType
from torch import nn

from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
is_graph_dict_space,
)


class GraphComplexInputNetwork(TorchModelV2, nn.Module):
def __init__(self, obs_space, action_space, num_outputs, model_config, name):
if not model_config.get("_disable_preprocessor_api"):
raise ValueError(
"This model is intent to be used only when preprocessors are disabled."
)
if not isinstance(obs_space, gym.spaces.Dict):
raise ValueError(
"This model is intent to be used only on dict observation space."
)

nn.Module.__init__(self)
super().__init__(obs_space, action_space, num_outputs, model_config, name)

self.gnn = nn.ModuleDict()
post_graph_obs_subspaces = dict(obs_space.spaces)
for k, subspace in obs_space.spaces.items():
if is_graph_dict_space(subspace):
submodel_name = f"gnn_{k}"
gnn = GnnBasedModel(
obs_space=subspace,
action_space=action_space,
num_outputs=None,
model_config=model_config,
framework="torch",
name=submodel_name,
)
self.add_module(submodel_name, gnn)
self.gnn[k] = gnn
post_graph_obs_subspaces[k] = gnn.features_space

post_graph_obs_space = gym.spaces.Dict(post_graph_obs_subspaces)
self.post_graph_model = ComplexInputNetwork(
obs_space=post_graph_obs_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=model_config,
name="post_graph_model",
)

def forward(self, input_dict: SampleBatch, state, seq_lens):
post_graph_input_dict = input_dict.copy(shallow=True)
obs = input_dict["obs"]
post_graph_obs = dict(obs)
for k, gnn in self.gnn.items():
post_graph_obs[k] = gnn(SampleBatch({SampleBatch.OBS: obs[k]}))
post_graph_input_dict["obs"] = post_graph_obs
return self.post_graph_model(
input_dict=post_graph_input_dict, state=state, seq_lens=seq_lens
)

def value_function(self) -> TensorType:
return self.post_graph_model.value_function()
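To make the intent concrete: this network targets a Dict observation space that mixes a Graph subspace with ordinary subspaces. scikit-decide converts the Graph part to its internal dict encoding before the model sees it; that key is then routed through GnnBasedModel while the remaining keys go through RLlib's ComplexInputNetwork. An illustrative space (not from the diff, key names arbitrary):

```python
import gymnasium as gym
import numpy as np

# Illustrative composite observation space: the "graph" entry ends up handled
# by a GnnBasedModel submodule, the "position" entry by the post-graph
# ComplexInputNetwork.
obs_space = gym.spaces.Dict(
    {
        "graph": gym.spaces.Graph(
            node_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
            edge_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(2,)),
        ),
        "position": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
    }
)
```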
77 changes: 77 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py
@@ -0,0 +1,77 @@
from collections import defaultdict
from typing import Optional

import gymnasium as gym
import numpy as np
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict
from torch import nn

from skdecide.hub.solver.ray_rllib.gnn.torch_layers import GraphFeaturesExtractor
from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
convert_dict_space_to_graph_space,
is_graph_dict_space,
)


class GnnBasedModel(TorchModelV2, nn.Module):
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: Optional[int],
model_config: ModelConfigDict,
name: str,
**kw,
):
nn.Module.__init__(self)
super().__init__(obs_space, action_space, num_outputs, model_config, name)

# config for custom model
custom_config = defaultdict(
lambda: None, # will return None for missing keys
model_config.get("custom_model_config", {}),
)

# gnn-based feature extractor
features_extractor_kwargs = custom_config.get("features_extractor", {})
assert is_graph_dict_space(
obs_space
), f"{self.__class__.__name__} can only be applied to Graph observation spaces."
graph_observation_space = convert_dict_space_to_graph_space(obs_space)
self.features_extractor = GraphFeaturesExtractor(
observation_space=graph_observation_space, **features_extractor_kwargs
)
self.features_space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(self.features_extractor.features_dim,)
)

if num_outputs is None:
# only feature extraction (e.g. to be used by GraphComplexInputNetwork)
self.num_outputs = self.features_extractor.features_dim
self.pred_action_embed_model = None
else:
# fully connected network
self.pred_action_embed_model = FullyConnectedNetwork(
obs_space=self.features_space,
action_space=action_space,
num_outputs=num_outputs,
model_config=model_config,
name=name + "_pred_action_embed",
)

def forward(self, input_dict, state, seq_lens):
obs = input_dict["obs"]
features = self.features_extractor(obs)
if self.pred_action_embed_model is None:
return features, state
else:
return self.pred_action_embed_model(
input_dict={"obs": features},
state=state,
seq_lens=seq_lens,
)

def value_function(self):
return self.pred_action_embed_model.value_function()
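The feature-extractor options reach GnnBasedModel through RLlib's `custom_model_config`, which it forwards as keyword arguments to GraphFeaturesExtractor. A hypothetical configuration sketch (the kwarg names below are assumptions for illustration, since GraphFeaturesExtractor's signature is not shown in this diff):

```python
# Hypothetical configuration sketch: GnnBasedModel looks up
# model_config["custom_model_config"]["features_extractor"] and forwards it
# as keyword arguments to GraphFeaturesExtractor. The kwarg names below are
# assumptions, not taken from the diff.
model_config = {
    "custom_model_config": {
        "features_extractor": {
            "gnn_out_dim": 128,  # assumed kwarg
            "features_dim": 64,  # assumed kwarg
        },
    },
}
```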
Empty file.
116 changes: 116 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py
@@ -0,0 +1,116 @@
from numbers import Number
from typing import Any, Optional, Union

import gymnasium as gym
import numpy as np
import tree
from ray.rllib import SampleBatch
from ray.rllib.policy.sample_batch import attempt_count_timesteps, tf, torch
from ray.rllib.utils.typing import ViewRequirementsDict


def _pop_graph_items(
full_dict: dict[Any, Any]
) -> dict[Any, Union[gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]]:
graph_dict = {}
for k, v in full_dict.items():
if isinstance(v, gym.spaces.GraphInstance) or (
isinstance(v, list) and isinstance(v[0], gym.spaces.GraphInstance)
):
graph_dict[k] = v
for k in graph_dict:
full_dict.pop(k)
return graph_dict


def _split_graph_requirements(
full_dict: ViewRequirementsDict,
) -> tuple[ViewRequirementsDict, ViewRequirementsDict]:
graph_dict = {}
for k, v in full_dict.items():
if isinstance(v.space, gym.spaces.Graph):
graph_dict[k] = v
wo_graph_dict = {k: v for k, v in full_dict.items() if k not in graph_dict}
return graph_dict, wo_graph_dict


class GraphSampleBatch(SampleBatch):
def __init__(self, *args, **kwargs):
"""Constructs a sample batch with possibly graph obs.

See `ray.rllib.SampleBatch` for more information.

"""
# split graph samples from others.
dict_graphs = _pop_graph_items(kwargs)
dict_from_args = dict(*args)
dict_graphs.update(_pop_graph_items(dict_from_args))

super().__init__(dict_from_args, **kwargs)
super().update(dict_graphs)

def copy(self, shallow: bool = False) -> "SampleBatch":
"""Creates a deep or shallow copy of this SampleBatch and returns it.

Args:
shallow: Whether the copying should be done shallowly.

Returns:
A deep or shallow copy of this SampleBatch object.
"""
copy_ = dict(self)
data = tree.map_structure(
lambda v: (
np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v
),
copy_,
)
copy_ = GraphSampleBatch(
data,
_time_major=self.time_major,
_zero_padded=self.zero_padded,
_max_seq_len=self.max_seq_len,
_num_grad_updates=self.num_grad_updates,
)
copy_.set_get_interceptor(self.get_interceptor)
copy_.added_keys = self.added_keys
copy_.deleted_keys = self.deleted_keys
copy_.accessed_keys = self.accessed_keys
return copy_

def get_single_step_input_dict(
self,
view_requirements: ViewRequirementsDict,
index: Union[str, int] = "last",
) -> "SampleBatch":
(
view_requirements_graphs,
view_requirements_wo_graphs,
) = _split_graph_requirements(view_requirements)
# w/o graphs
sample = GraphSampleBatch(
super().get_single_step_input_dict(view_requirements_wo_graphs, index)
)
# handle graphs
last_mappings = {
SampleBatch.OBS: SampleBatch.NEXT_OBS,
SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
}
for view_col, view_req in view_requirements_graphs.items():
if view_req.used_for_compute_actions is False:
continue

# Create batches of size 1 (single-agent input-dict).
data_col = view_req.data_col or view_col
if index == "last":
data_col = last_mappings.get(data_col, data_col)
if view_req.shift_from is not None:
raise NotImplementedError()
else:
sample[view_col] = self[data_col][-1:]
else:
sample[view_col] = self[data_col][
index : index + 1 if index != -1 else None
]
return sample
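An illustrative construction (not in the diff): GraphInstance values are pulled out before SampleBatch's usual array coercion in `__init__` and re-inserted afterwards, so they survive unchanged.

```python
import gymnasium as gym
import numpy as np
from ray.rllib import SampleBatch

from skdecide.hub.solver.ray_rllib.gnn.policy.sample_batch import GraphSampleBatch

graph_obs = [
    gym.spaces.GraphInstance(
        nodes=np.zeros((3, 4), dtype=np.float32),
        edges=np.zeros((2, 2), dtype=np.float32),
        edge_links=np.array([[0, 1], [1, 2]]),
    )
]
batch = GraphSampleBatch(
    {SampleBatch.OBS: graph_obs, SampleBatch.REWARDS: np.zeros(1, dtype=np.float32)}
)
# The graph observations are kept as GraphInstance objects instead of being
# coerced into numpy arrays.
assert isinstance(batch[SampleBatch.OBS][0], gym.spaces.GraphInstance)
```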
16 changes: 16 additions & 0 deletions skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py
@@ -0,0 +1,16 @@
import functools

from ray.rllib import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

from skdecide.hub.solver.ray_rllib.gnn.utils.torch_utils import convert_to_torch_tensor


class TorchGraphPolicy(TorchPolicyV2):
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
if not isinstance(postprocessed_batch, SampleBatch):
postprocessed_batch = SampleBatch(postprocessed_batch)
postprocessed_batch.set_get_interceptor(
functools.partial(convert_to_torch_tensor, device=device or self.device)
)
return postprocessed_batch
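The override exists so that graph observations reach the model without being forced through RLlib's default tensor conversion. The custom `convert_to_torch_tensor` itself is not shown in this diff; conceptually it needs to do something like the sketch below (an assumption-based illustration, not the actual implementation):

```python
import gymnasium as gym
import torch
from ray.rllib.utils.torch_utils import (
    convert_to_torch_tensor as rllib_convert_to_torch_tensor,
)


def convert_to_torch_tensor_sketch(x, device=None):
    """Sketch only: tensorise GraphInstance fields, defer everything else to RLlib."""
    if isinstance(x, gym.spaces.GraphInstance):
        return gym.spaces.GraphInstance(
            nodes=torch.as_tensor(x.nodes, device=device),
            edges=torch.as_tensor(x.edges, device=device),
            edge_links=torch.as_tensor(x.edge_links, device=device),
        )
    return rllib_convert_to_torch_tensor(x, device=device)
```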