3
3
4
4
import draccus
5
5
6
- from lerobot .configs .types import FeatureType
6
+ from lerobot .common .constants import ACTION , OBS_ENV , OBS_IMAGE , OBS_IMAGES , OBS_ROBOT
7
+ from lerobot .configs .types import FeatureType , PolicyFeature
7
8
8
9
9
10
@dataclass
10
11
class EnvConfig (draccus .ChoiceRegistry , abc .ABC ):
11
12
n_envs : int | None = None
12
13
task : str | None = None
13
14
fps : int = 30
14
- feature_types : dict = field (default_factory = dict )
15
+ features : dict [str , PolicyFeature ] = field (default_factory = dict )
16
+ features_map : dict [str , str ] = field (default_factory = dict )
15
17
16
18
@property
17
19
def type (self ) -> str :
@@ -28,17 +30,28 @@ class AlohaEnv(EnvConfig):
28
30
task : str = "AlohaInsertion-v0"
29
31
fps : int = 50
30
32
episode_length : int = 400
31
- feature_types : dict = field (
33
+ obs_type : str = "pixels_agent_pos"
34
+ render_mode : str = "rgb_array"
35
+ features : dict [str , PolicyFeature ] = field (
32
36
default_factory = lambda : {
33
- "agent_pos" : FeatureType .STATE ,
34
- "pixels" : {
35
- "top" : FeatureType .VISUAL ,
36
- },
37
- "action" : FeatureType .ACTION ,
37
+ "action" : PolicyFeature (type = FeatureType .ACTION , shape = (14 ,)),
38
+ }
39
+ )
40
+ features_map : dict [str , str ] = field (
41
+ default_factory = lambda : {
42
+ "action" : ACTION ,
43
+ "agent_pos" : OBS_ROBOT ,
44
+ "top" : f"{ OBS_IMAGE } .top" ,
45
+ "pixels/top" : f"{ OBS_IMAGES } .top" ,
38
46
}
39
47
)
40
- obs_type : str = "pixels_agent_pos"
41
- render_mode : str = "rgb_array"
48
+
49
+ def __post_init__ (self ):
50
+ if self .obs_type == "pixels" :
51
+ self .features ["top" ] = PolicyFeature (type = FeatureType .VISUAL , shape = (480 , 640 , 3 ))
52
+ elif self .obs_type == "pixels_agent_pos" :
53
+ self .features ["agent_pos" ] = PolicyFeature (type = FeatureType .STATE , shape = (14 ,))
54
+ self .features ["pixels/top" ] = PolicyFeature (type = FeatureType .VISUAL , shape = (480 , 640 , 3 ))
42
55
43
56
@property
44
57
def gym_kwargs (self ) -> dict :
@@ -55,25 +68,30 @@ class PushtEnv(EnvConfig):
55
68
task : str = "PushT-v0"
56
69
fps : int = 10
57
70
episode_length : int = 300
58
- feature_types : dict = field (
59
- default_factory = lambda : {
60
- "agent_pos" : FeatureType .STATE ,
61
- "pixels" : FeatureType .VISUAL ,
62
- "action" : FeatureType .ACTION ,
63
- }
64
- )
65
71
obs_type : str = "pixels_agent_pos"
66
72
render_mode : str = "rgb_array"
67
73
visualization_width : int = 384
68
74
visualization_height : int = 384
75
+ features : dict [str , PolicyFeature ] = field (
76
+ default_factory = lambda : {
77
+ "action" : PolicyFeature (type = FeatureType .ACTION , shape = (2 ,)),
78
+ "agent_pos" : PolicyFeature (type = FeatureType .STATE , shape = (2 ,)),
79
+ }
80
+ )
81
+ features_map : dict [str , str ] = field (
82
+ default_factory = lambda : {
83
+ "action" : ACTION ,
84
+ "agent_pos" : OBS_ROBOT ,
85
+ "environment_state" : OBS_ENV ,
86
+ "pixels" : OBS_IMAGE ,
87
+ }
88
+ )
69
89
70
90
def __post_init__ (self ):
71
- if self .obs_type == "environment_state_agent_pos" :
72
- self .feature_types = {
73
- "agent_pos" : FeatureType .STATE ,
74
- "environment_state" : FeatureType .ENV ,
75
- "action" : FeatureType .ACTION ,
76
- }
91
+ if self .obs_type == "pixels_agent_pos" :
92
+ self .features ["pixels" ] = PolicyFeature (type = FeatureType .VISUAL , shape = (384 , 384 , 3 ))
93
+ elif self .obs_type == "environment_state_agent_pos" :
94
+ self .features ["environment_state" ] = PolicyFeature (type = FeatureType .ENV , shape = (16 ,))
77
95
78
96
@property
79
97
def gym_kwargs (self ) -> dict :
@@ -91,17 +109,27 @@ class XarmEnv(EnvConfig):
91
109
task : str = "XarmLift-v0"
92
110
fps : int = 15
93
111
episode_length : int = 200
94
- feature_types : dict = field (
95
- default_factory = lambda : {
96
- "agent_pos" : FeatureType .STATE ,
97
- "pixels" : FeatureType .VISUAL ,
98
- "action" : FeatureType .ACTION ,
99
- }
100
- )
101
112
obs_type : str = "pixels_agent_pos"
102
113
render_mode : str = "rgb_array"
103
114
visualization_width : int = 384
104
115
visualization_height : int = 384
116
+ features : dict [str , PolicyFeature ] = field (
117
+ default_factory = lambda : {
118
+ "action" : PolicyFeature (type = FeatureType .ACTION , shape = (4 ,)),
119
+ "pixels" : PolicyFeature (type = FeatureType .VISUAL , shape = (84 , 84 , 3 )),
120
+ }
121
+ )
122
+ features_map : dict [str , str ] = field (
123
+ default_factory = lambda : {
124
+ "action" : ACTION ,
125
+ "agent_pos" : OBS_ROBOT ,
126
+ "pixels" : OBS_IMAGE ,
127
+ }
128
+ )
129
+
130
+ def __post_init__ (self ):
131
+ if self .obs_type == "pixels_agent_pos" :
132
+ self .features ["agent_pos" ] = PolicyFeature (type = FeatureType .STATE , shape = (4 ,))
105
133
106
134
@property
107
135
def gym_kwargs (self ) -> dict :
0 commit comments