Skip to content

Commit aa6d0c8

Browse files
authoredJul 13, 2023
Add vector wrappers for lambda observation, action and reward wrappers (#444)
1 parent ee067c7 commit aa6d0c8

File tree

15 files changed

+677
-51
lines changed

15 files changed

+677
-51
lines changed
 

‎docs/api/experimental/vector_wrappers.md

+25-11
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,46 @@ title: Vector Wrappers
88
.. autoclass:: gymnasium.experimental.vector.VectorWrapper
99
```
1010

11-
## Vector Lambda Observation Wrappers
11+
## Vector Observation Wrappers
1212

1313
```{eval-rst}
1414
.. autoclass:: gymnasium.experimental.vector.VectorObservationWrapper
15+
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaObservationV0
16+
.. autoclass:: gymnasium.experimental.wrappers.vector.FilterObservationV0
17+
.. autoclass:: gymnasium.experimental.wrappers.vector.FlattenObservationV0
18+
.. autoclass:: gymnasium.experimental.wrappers.vector.GrayscaleObservationV0
19+
.. autoclass:: gymnasium.experimental.wrappers.vector.ResizeObservationV0
20+
.. autoclass:: gymnasium.experimental.wrappers.vector.ReshapeObservationV0
21+
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleObservationV0
22+
.. autoclass:: gymnasium.experimental.wrappers.vector.DtypeObservationV0
1523
```
1624

17-
## Vector Lambda Action Wrappers
25+
## Vector Action Wrappers
1826

1927
```{eval-rst}
2028
.. autoclass:: gymnasium.experimental.vector.VectorActionWrapper
29+
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaActionV0
30+
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipActionV0
31+
.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleActionV0
2132
```
2233

23-
## Vector Lambda Reward Wrappers
34+
## Vector Reward Wrappers
2435

2536
```{eval-rst}
2637
.. autoclass:: gymnasium.experimental.vector.VectorRewardWrapper
38+
.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaRewardV0
39+
.. autoclass:: gymnasium.experimental.wrappers.vector.ClipRewardV0
2740
```
2841

29-
## Vector Common Wrappers
42+
## More Vector Wrappers
3043

3144
```{eval-rst}
32-
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorRecordEpisodeStatistics
33-
```
34-
35-
## Vector Only Wrappers
36-
37-
```{eval-rst}
38-
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorListInfo
45+
.. autoclass:: gymnasium.experimental.wrappers.vector.RecordEpisodeStatisticsV0
46+
.. autoclass:: gymnasium.experimental.wrappers.vector.DictInfoToListV0
47+
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaObservationV0
48+
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaActionV0
49+
.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaRewardV0
50+
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToNumpyV0
51+
.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToTorchV0
52+
.. autoclass:: gymnasium.experimental.wrappers.vector.NumpyToTorchV0
3953
```

‎gymnasium/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020

2121
# necessary for `envs.__init__` which registers all gymnasium environments and loads plugins
2222
from gymnasium import envs
23-
from gymnasium import experimental, spaces, utils, vector, wrappers, error, logger
23+
from gymnasium import spaces, utils, vector, wrappers, error, logger
24+
from gymnasium import experimental
2425

2526

2627
__all__ = [

‎gymnasium/experimental/__init__.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
"""Root __init__ of the gym experimental wrappers."""
22

33

4-
from gymnasium.experimental import functional, wrappers
5-
from gymnasium.experimental.functional import FuncEnv
6-
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
7-
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
8-
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
4+
from gymnasium.experimental import functional, vector, wrappers
5+
6+
7+
# from gymnasium.experimental.functional import FuncEnv
8+
# from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
9+
# from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
10+
# from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
911

1012

1113
__all__ = [
1214
# Functional
13-
"FuncEnv",
15+
# "FuncEnv",
1416
"functional",
1517
# Vector
16-
"VectorEnv",
17-
"VectorWrapper",
18-
"SyncVectorEnv",
19-
"AsyncVectorEnv",
18+
# "VectorEnv",
19+
# "VectorWrapper",
20+
# "SyncVectorEnv",
21+
# "AsyncVectorEnv",
2022
# wrappers
2123
"wrappers",
24+
"vector",
2225
]

‎gymnasium/experimental/vector/vector_env.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -397,32 +397,51 @@ def reset(
397397
) -> tuple[ObsType, dict[str, Any]]:
398398
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
399399
obs, info = self.env.reset(seed=seed, options=options)
400-
return self.observation(obs), info
400+
return self.vector_observation(obs), info
401401

402402
def step(
403403
self, actions: ActType
404404
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
405405
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
406406
observation, reward, termination, truncation, info = self.env.step(actions)
407407
return (
408-
self.observation(observation),
408+
self.vector_observation(observation),
409409
reward,
410410
termination,
411411
truncation,
412-
info,
412+
self.update_final_obs(info),
413413
)
414414

415-
def observation(self, observation: ObsType) -> ObsType:
416-
"""Defines the observation transformation.
415+
def vector_observation(self, observation: ObsType) -> ObsType:
416+
"""Defines the vector observation transformation.
417417
418418
Args:
419-
observation (object): the observation from the environment
419+
observation: A vector observation from the environment
420420
421421
Returns:
422-
observation (object): the transformed observation
422+
the transformed observation
423423
"""
424424
raise NotImplementedError
425425

426+
def single_observation(self, observation: ObsType) -> ObsType:
427+
"""Defines the single observation transformation.
428+
429+
Args:
430+
observation: A single observation from the environment
431+
432+
Returns:
433+
The transformed observation
434+
"""
435+
raise NotImplementedError
436+
437+
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
438+
"""Updates the `final_obs` in the info using `single_observation`."""
439+
if "final_observation" in info:
440+
for i, obs in enumerate(info["final_observation"]):
441+
if obs is not None:
442+
info["final_observation"][i] = self.single_observation(obs)
443+
return info
444+
426445

427446
class VectorActionWrapper(VectorWrapper):
428447
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""

‎gymnasium/experimental/wrappers/vector/__init__.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,57 @@
88
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
99
RecordEpisodeStatisticsV0,
1010
)
11+
from gymnasium.experimental.wrappers.vector.vectorize_action import (
12+
ClipActionV0,
13+
LambdaActionV0,
14+
RescaleActionV0,
15+
VectorizeLambdaActionV0,
16+
)
17+
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
18+
DtypeObservationV0,
19+
FilterObservationV0,
20+
FlattenObservationV0,
21+
GrayscaleObservationV0,
22+
LambdaObservationV0,
23+
RescaleObservationV0,
24+
ReshapeObservationV0,
25+
ResizeObservationV0,
26+
VectorizeLambdaObservationV0,
27+
)
28+
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
29+
ClipRewardV0,
30+
LambdaRewardV0,
31+
VectorizeLambdaRewardV0,
32+
)
1133

1234

1335
__all__ = [
1436
# --- Vector only wrappers
15-
# "VectoriseLambdaObservationV0",
16-
# "VectoriseLambdaActionV0",
17-
# "VectoriseLambdaRewardV0",
37+
"VectorizeLambdaObservationV0",
38+
"VectorizeLambdaActionV0",
39+
"VectorizeLambdaRewardV0",
1840
"DictInfoToListV0",
1941
# --- Observation wrappers ---
20-
# "LambdaObservationV0",
21-
# "FilterObservationV0",
22-
# "FlattenObservationV0",
23-
# "GrayscaleObservationV0",
24-
# "ResizeObservationV0",
25-
# "ReshapeObservationV0",
26-
# "RescaleObservationV0",
27-
# "DtypeObservationV0",
42+
"LambdaObservationV0",
43+
"FilterObservationV0",
44+
"FlattenObservationV0",
45+
"GrayscaleObservationV0",
46+
"ResizeObservationV0",
47+
"ReshapeObservationV0",
48+
"RescaleObservationV0",
49+
"DtypeObservationV0",
2850
# "PixelObservationV0",
2951
# "NormalizeObservationV0",
3052
# "TimeAwareObservationV0",
3153
# "FrameStackObservationV0",
3254
# "DelayObservationV0",
3355
# --- Action Wrappers ---
34-
# "LambdaActionV0",
35-
# "ClipActionV0",
36-
# "RescaleActionV0",
56+
"LambdaActionV0",
57+
"ClipActionV0",
58+
"RescaleActionV0",
3759
# --- Reward wrappers ---
38-
# "LambdaRewardV0",
39-
# "ClipRewardV0",
60+
"LambdaRewardV0",
61+
"ClipRewardV0",
4062
# "NormalizeRewardV1",
4163
# --- Common ---
4264
"RecordEpisodeStatisticsV0",

‎gymnasium/experimental/wrappers/vector/jax_to_numpy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from gymnasium.core import ActType, ObsType
99
from gymnasium.error import DependencyNotInstalled
10-
from gymnasium.experimental import VectorEnv, VectorWrapper
10+
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
1111
from gymnasium.experimental.vector.vector_env import ArrayType
1212
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
1313

@@ -19,7 +19,7 @@ class JaxToNumpyV0(VectorWrapper):
1919
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
2020
2121
Notes:
22-
A vectorised version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
22+
A vectorized version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
2323
2424
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
2525
"""

‎gymnasium/experimental/wrappers/vector/jax_to_torch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from gymnasium.core import ActType, ObsType
7-
from gymnasium.experimental import VectorEnv, VectorWrapper
7+
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
88
from gymnasium.experimental.vector.vector_env import ArrayType
99
from gymnasium.experimental.wrappers.jax_to_torch import (
1010
Device,

‎gymnasium/experimental/wrappers/vector/numpy_to_torch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from gymnasium.core import ActType, ObsType
7-
from gymnasium.experimental import VectorEnv, VectorWrapper
7+
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
88
from gymnasium.experimental.vector.vector_env import ArrayType
99
from gymnasium.experimental.wrappers.jax_to_torch import Device
1010
from gymnasium.experimental.wrappers.numpy_to_torch import (

0 commit comments

Comments
 (0)
Please sign in to comment.