
Commit aa228c0
Add features back into policy configs (#643)
1 parent 6bd9e12 · commit aa228c0

17 files changed: +328 -313 lines

lerobot/common/constants.py (+6)

@@ -0,0 +1,6 @@
+# keys
+OBS_ENV = "observation.environment_state"
+OBS_ROBOT = "observation.state"
+OBS_IMAGE = "observation.image"
+OBS_IMAGES = "observation.images"
+ACTION = "action"

lerobot/common/datasets/utils.py (+32 -1)

@@ -35,7 +35,7 @@
 from torchvision import transforms

 from lerobot.common.robot_devices.robots.utils import Robot
-from lerobot.configs.types import DictLike
+from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

 DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk

@@ -302,6 +302,37 @@ def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
     return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}


+def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
+    # TODO(aliberts): Implement "type" in dataset features and simplify this
+    policy_features = {}
+    for key, ft in features.items():
+        shape = ft["shape"]
+        if ft["dtype"] in ["image", "video"]:
+            type = FeatureType.VISUAL
+            if len(shape) != 3:
+                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
+
+            names = ft["names"]
+            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
+            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
+                shape = (shape[2], shape[0], shape[1])
+        elif key == "observation.environment_state":
+            type = FeatureType.ENV
+        elif key.startswith("observation"):
+            type = FeatureType.STATE
+        elif key == "action":
+            type = FeatureType.ACTION
+        else:
+            continue
+
+        policy_features[key] = PolicyFeature(
+            type=type,
+            shape=shape,
+        )
+
+    return policy_features
+
+
 def create_empty_dataset_info(
     codebase_version: str,
     fps: int,
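
A quick sketch of what dataset_to_policy_features does, using a hypothetical dataset features dict (dtypes, shapes, and names below are illustrative):

    from lerobot.common.datasets.utils import dataset_to_policy_features
    from lerobot.configs.types import FeatureType

    features = {
        "observation.state": {"dtype": "float32", "shape": (14,), "names": None},
        "observation.images.top": {
            "dtype": "video",
            "shape": (480, 640, 3),
            "names": ["height", "width", "channels"],
        },
        "action": {"dtype": "float32", "shape": (14,), "names": None},
        "timestamp": {"dtype": "float32", "shape": (1,), "names": None},  # dropped: not a policy feature
    }

    policy_features = dataset_to_policy_features(features)
    assert policy_features["observation.images.top"].type is FeatureType.VISUAL
    assert policy_features["observation.images.top"].shape == (3, 480, 640)  # (h, w, c) -> (c, h, w)
    assert "timestamp" not in policy_features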

lerobot/common/envs/configs.py (+58 -30)

@@ -3,15 +3,17 @@

 import draccus

-from lerobot.configs.types import FeatureType
+from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
+from lerobot.configs.types import FeatureType, PolicyFeature


 @dataclass
 class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
     n_envs: int | None = None
     task: str | None = None
     fps: int = 30
-    feature_types: dict = field(default_factory=dict)
+    features: dict[str, PolicyFeature] = field(default_factory=dict)
+    features_map: dict[str, str] = field(default_factory=dict)

     @property
     def type(self) -> str:

@@ -28,17 +30,28 @@ class AlohaEnv(EnvConfig):
     task: str = "AlohaInsertion-v0"
     fps: int = 50
     episode_length: int = 400
-    feature_types: dict = field(
+    obs_type: str = "pixels_agent_pos"
+    render_mode: str = "rgb_array"
+    features: dict[str, PolicyFeature] = field(
         default_factory=lambda: {
-            "agent_pos": FeatureType.STATE,
-            "pixels": {
-                "top": FeatureType.VISUAL,
-            },
-            "action": FeatureType.ACTION,
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "agent_pos": OBS_ROBOT,
+            "top": f"{OBS_IMAGE}.top",
+            "pixels/top": f"{OBS_IMAGES}.top",
         }
     )
-    obs_type: str = "pixels_agent_pos"
-    render_mode: str = "rgb_array"
+
+    def __post_init__(self):
+        if self.obs_type == "pixels":
+            self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
+        elif self.obs_type == "pixels_agent_pos":
+            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
+            self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))

     @property
     def gym_kwargs(self) -> dict:

@@ -55,25 +68,30 @@ class PushtEnv(EnvConfig):
     task: str = "PushT-v0"
     fps: int = 10
     episode_length: int = 300
-    feature_types: dict = field(
-        default_factory=lambda: {
-            "agent_pos": FeatureType.STATE,
-            "pixels": FeatureType.VISUAL,
-            "action": FeatureType.ACTION,
-        }
-    )
     obs_type: str = "pixels_agent_pos"
     render_mode: str = "rgb_array"
     visualization_width: int = 384
     visualization_height: int = 384
+    features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
+            "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "agent_pos": OBS_ROBOT,
+            "environment_state": OBS_ENV,
+            "pixels": OBS_IMAGE,
+        }
+    )

     def __post_init__(self):
-        if self.obs_type == "environment_state_agent_pos":
-            self.feature_types = {
-                "agent_pos": FeatureType.STATE,
-                "environment_state": FeatureType.ENV,
-                "action": FeatureType.ACTION,
-            }
+        if self.obs_type == "pixels_agent_pos":
+            self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
+        elif self.obs_type == "environment_state_agent_pos":
+            self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))

     @property
     def gym_kwargs(self) -> dict:

@@ -91,17 +109,27 @@ class XarmEnv(EnvConfig):
     task: str = "XarmLift-v0"
     fps: int = 15
     episode_length: int = 200
-    feature_types: dict = field(
-        default_factory=lambda: {
-            "agent_pos": FeatureType.STATE,
-            "pixels": FeatureType.VISUAL,
-            "action": FeatureType.ACTION,
-        }
-    )
     obs_type: str = "pixels_agent_pos"
     render_mode: str = "rgb_array"
     visualization_width: int = 384
     visualization_height: int = 384
+    features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
+            "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "agent_pos": OBS_ROBOT,
+            "pixels": OBS_IMAGE,
+        }
+    )
+
+    def __post_init__(self):
+        if self.obs_type == "pixels_agent_pos":
+            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))

     @property
     def gym_kwargs(self) -> dict:
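
A short sketch of how `features` and `features_map` combine in the configs above: `features` holds per-key shapes and types, while `features_map` renames env keys into the policy's "observation.*" namespace. Output comments are illustrative:

    from lerobot.common.envs.configs import PushtEnv

    cfg = PushtEnv()  # obs_type defaults to "pixels_agent_pos"
    # __post_init__ filled in the obs_type-dependent entry, so cfg.features
    # now holds "action", "agent_pos", and "pixels" (shape (384, 384, 3)).
    print(cfg.features_map["pixels"])     # observation.image
    print(cfg.features_map["agent_pos"])  # observation.state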

lerobot/common/envs/utils.py (+25)

@@ -18,6 +18,10 @@
 import torch
 from torch import Tensor

+from lerobot.common.envs.configs import EnvConfig
+from lerobot.common.utils.utils import get_channel_first_image_shape
+from lerobot.configs.types import FeatureType, PolicyFeature
+

 def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)

@@ -36,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
            imgs = {"observation.image": observations["pixels"]}

        for imgkey, img in imgs.items():
+           # TODO(aliberts, rcadene): use transforms.ToTensor()?
            img = torch.from_numpy(img)

            # sanity check that images are channel last

@@ -61,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     # requirement for "agent_pos"
     return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
     return return_observations
+
+
+def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
+    # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
+    # (need to also refactor preprocess_observation and externalize normalization from policies)
+    policy_features = {}
+    for key, ft in env_cfg.features.items():
+        if ft.type is FeatureType.VISUAL:
+            if len(ft.shape) != 3:
+                raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
+
+            shape = get_channel_first_image_shape(ft.shape)
+            feature = PolicyFeature(type=ft.type, shape=shape)
+        else:
+            feature = ft
+
+        policy_key = env_cfg.features_map[key]
+        policy_features[policy_key] = feature
+
+    return policy_features
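
Tying the config and helper together, a hedged sketch that assumes get_channel_first_image_shape reorders (h, w, c) to (c, h, w):

    from lerobot.common.envs.configs import XarmEnv
    from lerobot.common.envs.utils import env_to_policy_features
    from lerobot.configs.types import FeatureType

    cfg = XarmEnv()  # obs_type defaults to "pixels_agent_pos"
    policy_features = env_to_policy_features(cfg)
    # "pixels" (84, 84, 3) is renamed via features_map and made channel-first:
    assert policy_features["observation.image"].shape == (3, 84, 84)
    assert policy_features["observation.state"].type is FeatureType.STATE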

lerobot/common/policies/act/modeling_act.py (+10 -6)

@@ -68,9 +68,13 @@ def __init__(
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(config.input_features, dataset_stats)
-        self.normalize_targets = Normalize(config.output_features, dataset_stats)
-        self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats)
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_targets = Normalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+        self.unnormalize_outputs = Unnormalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )

         self.model = ACT(config)

@@ -121,7 +125,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack(
-                [batch[ft.key] for ft in self.config.image_features], dim=-4
+                [batch[key] for key in self.config.image_features], dim=-4
             )

         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions

@@ -151,7 +155,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack(
-                [batch[ft.key] for ft in self.config.image_features], dim=-4
+                [batch[key] for key in self.config.image_features], dim=-4
             )
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

@@ -411,7 +415,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
         """
         if self.config.use_vae and self.training:
             assert (
-                self.config.action_feature.key in batch
+                "action" in batch
             ), "actions must be provided when using the variational objective in training mode."

             batch_size = (
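
The normalization layers now take the config's normalization_mapping as a second positional argument. A minimal sketch of the new call shape, assuming NormalizationMode lives in lerobot.configs.types alongside FeatureType (the mapping below is a hypothetical example keyed by feature-type name):

    from lerobot.common.policies.normalize import Normalize
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

    input_features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,))}
    norm_map = {"STATE": NormalizationMode.MEAN_STD}  # hypothetical mapping
    # Stats may be None at construction time and loaded later from a checkpoint.
    normalize_inputs = Normalize(input_features, norm_map, None)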

lerobot/common/policies/diffusion/configuration_diffusion.py (+5 -6)

@@ -208,21 +208,20 @@ def validate_features(self) -> None:
            raise ValueError("You must provide at least one image or the environment state among the inputs.")

        if self.crop_shape is not None:
-           for image_ft in self.image_features:
+           for key, image_ft in self.image_features.items():
                if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                    raise ValueError(
                        f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                        f"for `crop_shape` and {image_ft.shape} for "
-                       f"`{image_ft.key}`."
+                       f"`{key}`."
                    )

        # Check that all input images have the same shape.
-       first_image_ft = next(iter(self.image_features))
-       for image_ft in self.image_features:
+       first_image_key, first_image_ft = next(iter(self.image_features.items()))
+       for key, image_ft in self.image_features.items():
            if image_ft.shape != first_image_ft.shape:
                raise ValueError(
-                   f"`{image_ft.key}` does not match `{first_image_ft.key}`, but we "
-                   "expect all image shapes to match."
+                   f"`{key}` does not match `{first_image_key}`, but we " "expect all image shapes to match."
                )

    @property
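
The underlying pattern change, in a standalone sketch: feature containers are now plain dicts keyed by name instead of objects carrying their own .key, so validation iterates over (key, feature) pairs (the names below are illustrative):

    from lerobot.configs.types import FeatureType, PolicyFeature

    image_features = {
        "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 480, 640)),
        "observation.images.wrist": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 480, 640)),
    }
    first_key, first_ft = next(iter(image_features.items()))
    for key, ft in image_features.items():
        assert ft.shape == first_ft.shape, f"`{key}` does not match `{first_key}`"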

lerobot/common/policies/diffusion/modeling_diffusion.py (+16 -12)

@@ -34,6 +34,7 @@
 from huggingface_hub import PyTorchModelHubMixin
 from torch import Tensor, nn

+from lerobot.common.constants import OBS_ENV, OBS_ROBOT
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.utils import (

@@ -74,9 +75,13 @@ def __init__(
         config.validate_features()
         self.config = config

-        self.normalize_inputs = Normalize(config.input_features, dataset_stats)
-        self.normalize_targets = Normalize(config.output_features, dataset_stats)
-        self.unnormalize_outputs = Unnormalize(config.output_features, dataset_stats)
+        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_targets = Normalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )
+        self.unnormalize_outputs = Unnormalize(
+            config.output_features, config.normalization_mapping, dataset_stats
+        )

         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None

@@ -125,7 +130,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack(
-                [batch[ft.key] for ft in self.config.image_features], dim=-4
+                [batch[key] for key in self.config.image_features], dim=-4
             )
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)

@@ -149,7 +154,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
             batch["observation.images"] = torch.stack(
-                [batch[ft.key] for ft in self.config.image_features], dim=-4
+                [batch[key] for key in self.config.image_features], dim=-4
             )
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)

@@ -237,8 +242,8 @@ def conditional_sample(

     def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
         """Encode image features and concatenate them all together along with the state vector."""
-        batch_size, n_obs_steps = batch[self.config.robot_state_feature.key].shape[:2]
-        global_cond_feats = [batch[self.config.robot_state_feature.key]]
+        batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
+        global_cond_feats = [batch[OBS_ROBOT]]
         # Extract image features.
         if self.config.image_features:
             if self.config.use_separate_rgb_encoder_per_camera:

@@ -268,7 +273,7 @@ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
             global_cond_feats.append(img_features)

         if self.config.env_state_feature:
-            global_cond_feats.append(batch[self.config.env_state_feature.key])
+            global_cond_feats.append(batch[OBS_ENV])

         # Concatenate features then flatten to (B, global_cond_dim).
         return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

@@ -482,10 +487,9 @@ def __init__(self, config: DiffusionConfig):
         # height and width from `config.image_features`.

         # Note: we have a check in the config class to make sure all images have the same shape.
-        dummy_shape_h_w = (
-            config.crop_shape if config.crop_shape is not None else config.image_features[0].shape[1:]
-        )
-        dummy_shape = (1, config.image_features[0].shape[0], *dummy_shape_h_w)
+        images_shape = next(iter(config.image_features.values())).shape
+        dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
+        dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
         feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]

         self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
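
As a standalone sketch, the dim=-4 stacking used in select_action/forward above inserts a camera axis just before (c, h, w); tensors and keys here are illustrative:

    import torch

    # Two hypothetical camera streams, each (batch, channels, height, width):
    batch = {
        "observation.images.top": torch.zeros(2, 3, 96, 96),
        "observation.images.wrist": torch.zeros(2, 3, 96, 96),
    }
    # Iterating a dict yields its keys, matching `for key in self.config.image_features`.
    image_keys = list(batch)
    batch["observation.images"] = torch.stack([batch[k] for k in image_keys], dim=-4)
    assert batch["observation.images"].shape == (2, 2, 3, 96, 96)  # (b, num_cameras, c, h, w)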
