Add low-pass filter to make robot actions smooth. Integrated into ACT and DP. #760

Open · wants to merge 11 commits into base: main
3 changes: 3 additions & 0 deletions lerobot/common/policies/act/modeling_act.py
@@ -36,6 +36,7 @@
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import smoothen_actions


class ACTPolicy(PreTrainedPolicy):
@@ -138,6 +139,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:

            # TODO(rcadene): make _forward return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]
            # use a low-pass filter to prevent jerky actions
            actions = smoothen_actions(actions)

            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
3 changes: 3 additions & 0 deletions lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -42,6 +42,7 @@
    get_dtype_from_parameters,
    get_output_shape,
    populate_queues,
    smoothen_actions,
)


@@ -137,6 +138,8 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:

            # TODO(rcadene): make above methods return output dictionary?
            actions = self.unnormalize_outputs({"action": actions})["action"]
            # use a low-pass filter to prevent jerky actions
            actions = smoothen_actions(actions)

            self._queues["action"].extend(actions.transpose(0, 1))
46 changes: 46 additions & 0 deletions lerobot/common/policies/utils.py
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from scipy.signal import butter, filtfilt
from torch import nn


@@ -65,3 +67,47 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
    with torch.inference_mode():
        output = module(dummy_input)
    return tuple(output.shape)


def butterworth_lowpass_filter(
    data: np.ndarray, cutoff_freq: float = 1.0, sampling_freq: float = 15.0, order: int = 2
) -> np.ndarray:
    """
    Applies a low-pass Butterworth filter to the input data.

    Parameters:
        data (np.ndarray): Input data array; filtered along axis 0.
        cutoff_freq (float): Cutoff frequency of the filter (Hz). Lower values give smoother output.
        sampling_freq (float): Sampling frequency of the data (Hz).
        order (int): Order of the filter. filtfilt runs the filter forward and backward, so the result
            is zero-phase; higher orders sharpen the cutoff but can be numerically less stable.

    Returns:
        filtered_data (np.ndarray): Filtered data array with the same shape as `data`.
    """
    nyquist = 0.5 * sampling_freq
    normal_cutoff = cutoff_freq / nyquist
    b, a = butter(order, normal_cutoff, btype="low", analog=False)

    # apply the zero-phase filter along axis 0 (the action-step / time axis)
    filtered_data = filtfilt(b, a, data, axis=0)
    return filtered_data


def smoothen_actions(actions: torch.Tensor) -> torch.Tensor:
    """
    Smooths the provided action sequence with a low-pass Butterworth filter.

    Args:
        actions (torch.Tensor): Action chunk from the policy, shaped (1, n_action_steps, action_dim).

    Returns:
        torch.Tensor: Smoothed actions of shape (1, n_action_steps, action_dim) on the same device
            as the input, with the gripper (last) dimension left unfiltered.
    """
    if not isinstance(actions, torch.Tensor):
        raise ValueError(f"Invalid input type for actions: {type(actions)}. Expected torch.Tensor!")

    if actions.ndim == 3 and actions.shape[0] != 1:
        raise NotImplementedError("Batch processing not implemented!")

    actions_np = actions.squeeze(0).cpu().numpy()
    # apply the low-pass filter along the action-step axis
    filtered_actions_np = butterworth_lowpass_filter(actions_np.copy())
    # disable filtering for the gripper joint (last action dimension)
    filtered_actions_np[:, -1] = actions_np[:, -1]
    # copy() because filtfilt can return an array with negative strides, which torch.from_numpy rejects
    return torch.from_numpy(filtered_actions_np.copy()).unsqueeze(0).to(actions.device)
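
Usage note (not part of the diff): a minimal sketch of how `smoothen_actions` behaves on a synthetic action chunk. The 1 Hz cutoff and 15 Hz sampling rate come from the defaults above; the 3-dimensional action, the assumption that the last dimension is a binary gripper command, and the noise level are made up purely for illustration.

import numpy as np
import torch

from lerobot.common.policies.utils import smoothen_actions

# Synthetic (1, n_action_steps, action_dim) chunk: two smooth joint trajectories plus noise,
# with the last column standing in for a binary gripper command.
rng = np.random.default_rng(0)
t = np.linspace(0.0, 2.0 * np.pi, 100)
clean = np.stack([np.sin(t), np.cos(t), (t > np.pi).astype(np.float64)], axis=1)
noisy = clean.copy()
noisy[:, :-1] += rng.normal(scale=0.05, size=(100, 2))

actions = torch.from_numpy(noisy).unsqueeze(0)  # shape (1, 100, 3)
smoothed = smoothen_actions(actions)

# The filtered joints should sit closer to the clean trajectory than the noisy input...
print("noisy error   :", np.abs(noisy[:, :-1] - clean[:, :-1]).mean())
print("smoothed error:", np.abs(smoothed[0, :, :-1].numpy() - clean[:, :-1]).mean())
# ...while the gripper column passes through the filter untouched.
assert torch.equal(smoothed[0, :, -1], actions[0, :, -1])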