From aac99cdf86f998c09d3e5266182f3016e005b918 Mon Sep 17 00:00:00 2001
From: Gijs de Jong
Date: Mon, 10 Feb 2025 11:22:42 +0100
Subject: [PATCH] infer entity paths from feature names in dataset

---
 crates/store/re_data_loader/src/lerobot.rs       | 23 +++++++++++++++
 .../re_data_loader/src/loader_lerobot.rs         | 29 +++++++++++++------
 2 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/crates/store/re_data_loader/src/lerobot.rs b/crates/store/re_data_loader/src/lerobot.rs
index 1809a362549c..69bb51c4cf74 100644
--- a/crates/store/re_data_loader/src/lerobot.rs
+++ b/crates/store/re_data_loader/src/lerobot.rs
@@ -352,6 +352,7 @@ impl LeRobotDatasetInfo {
 pub struct Feature {
     pub dtype: DType,
     pub shape: Vec<usize>,
+    pub names: Option<Names>,
 }
 
 /// Data types supported for features in a `LeRobot` dataset.
@@ -366,6 +367,28 @@ pub enum DType {
     Int64,
 }
 
+/// Name metadata for a feature in the `LeRobot` dataset.
+///
+/// The name metadata can consist of
+/// - A flat list of names for each dimension of a feature (e.g., `["height", "width", "channel"]`).
+/// - A list specific to motors (e.g., `{ "motors": ["motor_0", "motor_1", ...] }`).
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Names {
+    Motors { motors: Vec<String> },
+    List(Vec<String>),
+}
+
+impl Names {
+    /// Retrieves the name corresponding to a specific index within the `names` field of a feature.
+    pub fn name_for_index(&self, index: usize) -> Option<&String> {
+        match self {
+            Self::Motors { motors } => motors.get(index),
+            Self::List(items) => items.get(index),
+        }
+    }
+}
+
 // TODO(gijsd): Do we want to stream in episodes or tasks?
 #[cfg(not(target_arch = "wasm32"))]
 fn load_jsonl_file<D>(filepath: impl AsRef<Path>) -> Result<Vec<D>, LeRobotError>
diff --git a/crates/store/re_data_loader/src/loader_lerobot.rs b/crates/store/re_data_loader/src/loader_lerobot.rs
index 9b6897623e01..b69631dbf855 100644
--- a/crates/store/re_data_loader/src/loader_lerobot.rs
+++ b/crates/store/re_data_loader/src/loader_lerobot.rs
@@ -16,7 +16,7 @@ use re_types::archetypes::{AssetVideo, EncodedImage, VideoFrameReference};
 use re_types::components::{Scalar, VideoTimestamp};
 use re_types::{Archetype, Component, ComponentBatch};
 
-use crate::lerobot::{is_le_robot_dataset, DType, EpisodeIndex, LeRobotDataset};
+use crate::lerobot::{is_le_robot_dataset, DType, EpisodeIndex, Feature, LeRobotDataset};
 use crate::{DataLoader, DataLoaderError, LoadedData};
 
 /// Columns in the `LeRobot` dataset schema that we do not visualize in the viewer, and thus ignore.
@@ -112,7 +112,7 @@ impl DataLoader for LeRobotDatasetLoader {
                 );
             }
             DType::Float32 | DType::Float64 => {
-                chunks.extend(load_scalar(feature_key, &timelines, &data)?);
+                chunks.extend(load_scalar(feature_key, feature, &timelines, &data)?);
             }
         }
     }
@@ -269,19 +269,22 @@ impl Iterator for ScalarChunkIterator {
 impl ExactSizeIterator for ScalarChunkIterator {}
 
 fn load_scalar(
-    feature: &str,
+    feature_key: &str,
+    feature: &Feature,
     timelines: &IntMap<Timeline, TimeColumn>,
     data: &RecordBatch,
 ) -> Result<ScalarChunkIterator, DataLoaderError> {
     let field = data
         .schema_ref()
-        .field_with_name(feature)
-        .with_context(|| format!("Failed to get field for feature {feature} from parquet file"))?;
+        .field_with_name(feature_key)
+        .with_context(|| {
+            format!("Failed to get field for feature {feature_key} from parquet file")
+        })?;
 
     match field.data_type() {
         DataType::FixedSizeList(_, _) => {
             let fixed_size_array = data
-                .column_by_name(feature)
+                .column_by_name(feature_key)
                 .and_then(|col| col.downcast_array_ref::<FixedSizeListArray>())
                 .ok_or_else(|| {
                     DataLoaderError::Other(anyhow!(
@@ -289,11 +292,12 @@ fn load_scalar(
                 ))
             })?;
-            let batch_chunks = make_scalar_batch_entity_chunks(field, timelines, fixed_size_array)?;
+            let batch_chunks =
+                make_scalar_batch_entity_chunks(field, feature, timelines, fixed_size_array)?;
             Ok(ScalarChunkIterator::Batch(Box::new(batch_chunks)))
         }
         DataType::Float32 => {
-            let feature_data = data.column_by_name(feature).ok_or_else(|| {
+            let feature_data = data.column_by_name(feature_key).ok_or_else(|| {
                 DataLoaderError::Other(anyhow!(
                     "Failed to get LeRobot dataset column data for: {:?}",
                     field.name()
                 ))
@@ -321,6 +325,7 @@ fn load_scalar(
 
 fn make_scalar_batch_entity_chunks(
     field: &Field,
+    feature: &Feature,
     timelines: &IntMap<Timeline, TimeColumn>,
     data: &FixedSizeListArray,
 ) -> Result<impl ExactSizeIterator<Item = Chunk>, DataLoaderError> {
@@ -330,7 +335,13 @@ fn make_scalar_batch_entity_chunks(
     let mut chunks = Vec::with_capacity(num_elements);
 
     for idx in 0..num_elements {
-        let entity_path = format!("{}/{idx}", field.name());
+        let name = feature
+            .names
+            .as_ref()
+            .and_then(|names| names.name_for_index(idx).cloned())
+            .unwrap_or(format!("{idx}"));
+
+        let entity_path = format!("{}/{name}", field.name());
         chunks.push(make_scalar_entity_chunk(
             entity_path.into(),
             timelines,
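
For reference, the two `names` shapes described in the new doc comment on
`Names` deserialize as sketched below. This is a minimal, standalone sketch
(assuming `serde` and `serde_json` as dependencies; the enum here mirrors
the one this patch adds to `lerobot.rs`):

    use serde::Deserialize;

    // Mirror of the untagged `Names` enum added above. With
    // `#[serde(untagged)]`, serde tries the variants in declaration order:
    // the `Motors` map shape first, then the flat list.
    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(untagged)]
    enum Names {
        Motors { motors: Vec<String> },
        List(Vec<String>),
    }

    fn main() {
        // Flat per-dimension list, e.g. for an image feature.
        let list: Names =
            serde_json::from_str(r#"["height", "width", "channel"]"#).unwrap();
        assert_eq!(
            list,
            Names::List(vec!["height".into(), "width".into(), "channel".into()])
        );

        // Motor-specific map, e.g. for a joint-state feature.
        let motors: Names =
            serde_json::from_str(r#"{ "motors": ["motor_0", "motor_1"] }"#).unwrap();
        assert_eq!(
            motors,
            Names::Motors { motors: vec!["motor_0".into(), "motor_1".into()] }
        );
    }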
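
The entity-path fallback added to `make_scalar_batch_entity_chunks` can be
exercised on its own. Another standalone sketch (std only; `entity_paths`
is a hypothetical helper mirroring the loop above, not an API from this
crate):

    fn entity_paths(feature_key: &str, names: Option<&[&str]>, num_elements: usize) -> Vec<String> {
        (0..num_elements)
            .map(|idx| {
                // Same fallback as the patch: take the per-dimension name if
                // the metadata provides one, otherwise use the index itself.
                let name = names
                    .and_then(|names| names.get(idx))
                    .map_or_else(|| idx.to_string(), |name| (*name).to_owned());
                format!("{feature_key}/{name}")
            })
            .collect()
    }

    fn main() {
        // `{ "motors": ["motor_0", "motor_1"] }` on a 3-element feature:
        // the unnamed third dimension falls back to its index, so the
        // resulting entity paths stay unique.
        assert_eq!(
            entity_paths("observation.state", Some(&["motor_0", "motor_1"]), 3),
            [
                "observation.state/motor_0",
                "observation.state/motor_1",
                "observation.state/2",
            ]
        );
    }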