@@ -397,32 +397,51 @@ def reset(
397
397
) -> tuple [ObsType , dict [str , Any ]]:
398
398
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
399
399
obs , info = self .env .reset (seed = seed , options = options )
400
- return self .observation (obs ), info
400
+ return self .vector_observation (obs ), info
401
401
402
402
def step (
403
403
self , actions : ActType
404
404
) -> tuple [ObsType , ArrayType , ArrayType , ArrayType , dict ]:
405
405
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
406
406
observation , reward , termination , truncation , info = self .env .step (actions )
407
407
return (
408
- self .observation (observation ),
408
+ self .vector_observation (observation ),
409
409
reward ,
410
410
termination ,
411
411
truncation ,
412
- info ,
412
+ self . update_final_obs ( info ) ,
413
413
)
414
414
415
- def observation (self , observation : ObsType ) -> ObsType :
416
- """Defines the observation transformation.
415
+ def vector_observation (self , observation : ObsType ) -> ObsType :
416
+ """Defines the vector observation transformation.
417
417
418
418
Args:
419
- observation (object): the observation from the environment
419
+ observation: A vector observation from the environment
420
420
421
421
Returns:
422
- observation (object): the transformed observation
422
+ the transformed observation
423
423
"""
424
424
raise NotImplementedError
425
425
426
+ def single_observation (self , observation : ObsType ) -> ObsType :
427
+ """Defines the single observation transformation.
428
+
429
+ Args:
430
+ observation: A single observation from the environment
431
+
432
+ Returns:
433
+ The transformed observation
434
+ """
435
+ raise NotImplementedError
436
+
437
+ def update_final_obs (self , info : dict [str , Any ]) -> dict [str , Any ]:
438
+ """Updates the `final_obs` in the info using `single_observation`."""
439
+ if "final_observation" in info :
440
+ for i , obs in enumerate (info ["final_observation" ]):
441
+ if obs is not None :
442
+ info ["final_observation" ][i ] = self .single_observation (obs )
443
+ return info
444
+
426
445
427
446
class VectorActionWrapper (VectorWrapper ):
428
447
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
0 commit comments