Skip to content

Commit

Permalink
HITL - Data collection (#1967)
Browse files Browse the repository at this point in the history
* Add session management.

* Formatting changes.

* Add clarifications to episode resolution.

* Document temporary hack to check for client-side loading status.

* Add session recorder, ui events and data upload.

* Change path handling in session upload code.
  • Loading branch information
0mdc authored May 20, 2024
1 parent 11ec7bc commit 409d0c3
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 20 deletions.
63 changes: 61 additions & 2 deletions examples/hitl/rearrange_v2/app_state_end_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
from typing import Optional

from app_data import AppData
from app_state_base import AppStateBase
from app_states import create_app_state_reset
from s3_upload import (
generate_unique_session_id,
make_s3_filename,
upload_file_to_s3,
)
from session import Session
from util import get_top_down_view

from habitat_hitl.app_states.app_service import AppService
from habitat_hitl.core.serialize_utils import save_as_json_gzip
from habitat_hitl.core.user_mask import Mask

# Duration of the end session message, before users are kicked.
Expand Down Expand Up @@ -55,5 +63,56 @@ def sim_update(self, dt: float, post_sim_update_dict):
self._elapsed_time += dt

def _end_session(self):
# TODO: Data collection.
pass
session = self._session
if session is None:
print("Null session. Skipping S3 upload.")
return

# Finalize session.
if self._session.error == "":
session.success = True
session.session_recorder.end_session(self._session.error)

# Get data collection parameters.
try:
config = self._app_service.config
data_collection_config = config.rearrange_v2.data_collection
s3_path = data_collection_config.s3_path
s3_subdir = "complete" if session.success else "incomplete"
s3_path = os.path.join(s3_path, s3_subdir)

# Use the port as a discriminator for when there are multiple concurrent servers.
output_folder_suffix = str(config.habitat_hitl.networking.port)
output_folder = f"output_{output_folder_suffix}"

output_file_name = data_collection_config.output_file_name
output_file = f"{output_file_name}.json.gz"

except Exception as e:
print(f"Invalid data collection config. Skipping S3 upload. {e}")
return

# Delete previous output directory
if os.path.exists(output_folder):
shutil.rmtree(output_folder)

# Create new output directory
os.makedirs(output_folder)
json_path = os.path.join(output_folder, output_file)
save_as_json_gzip(session.session_recorder, json_path)

# Generate unique session ID
session_id = generate_unique_session_id(
session.episode_ids, session.connection_records
)

# Upload output directory
orig_file_names = [
f
for f in os.listdir(output_folder)
if os.path.isfile(os.path.join(output_folder, f))
]
for orig_file_name in orig_file_names:
local_file_path = os.path.join(output_folder, orig_file_name)
s3_file_name = make_s3_filename(session_id, orig_file_name)
upload_file_to_s3(local_file_path, s3_file_name, s3_path)
5 changes: 5 additions & 0 deletions examples/hitl/rearrange_v2/config/language_rearrange.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@ habitat:
dataset:
type: "CollaborationDataset-v0"
data_path: data/datasets/hssd/llm_rearrange/v2/60scenes_dataset_776eps_with_eval.json.gz

rearrange_v2:
data_collection:
s3_path: "Placeholder/"
output_file_name: "session"
119 changes: 105 additions & 14 deletions examples/hitl/rearrange_v2/rearrange_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import magnum as mn
import numpy as np
Expand Down Expand Up @@ -41,10 +41,14 @@
PIP_VIEWPORT_ID = 0 # ID of the picture-in-picture viewport that shows other agent's perspective.


class DataLogger:
def __init__(self, app_service: AppService):
class FrameRecorder:
def __init__(
self, app_service: AppService, app_data: AppData, world: World
):
self._app_service = app_service
self._app_data = app_data
self._sim = app_service.sim
self._world = world

def get_num_agents(self):
return len(self._sim.agents_mgr._all_agent_data)
Expand Down Expand Up @@ -87,15 +91,29 @@ def get_objects_state(self):
)
return object_states

def record_state(self, task_completed: bool = False):
agent_states = self.get_agents_state()
object_states = self.get_objects_state()

self._app_service.step_recorder.record("agent_states", agent_states)
self._app_service.step_recorder.record("object_states", object_states)
self._app_service.step_recorder.record(
"task_completed", task_completed
)
def record_state(
self, elapsed_time: float, user_data: List[UserData]
) -> Dict[str, Any]:
data: Dict[str, Any] = {
"t": elapsed_time,
"users": [],
"object_states": self.get_objects_state(),
"agent_states": self.get_agents_state(),
}

for user_index in range(len(user_data)):
u = user_data[user_index]
user_data_dict = {
"task_completed": u.episode_finished,
"task_succeeded": u.episode_success,
"camera_transform": u.cam_transform,
"held_object": u.ui._held_object_id,
"hovered_object": u.ui._hover_selection.object_id,
"events": u.pop_ui_events(),
}
data["users"].append(user_data_dict)

return data


class UserData:
Expand Down Expand Up @@ -124,6 +142,9 @@ def __init__(
self.task_instruction = ""
self.pip_initialized = False

# Events for data collection.
self.ui_events: List[Dict[str, Any]] = []

# If in remote mode, get the remote input. Else get the server (local) input.
self.gui_input = (
app_service.remote_client_state.get_gui_input(user_index)
Expand All @@ -149,6 +170,12 @@ def __init__(
camera_helper=self.camera_helper,
)

# Register UI callbacks
self.ui.on_pick.registerCallback(self._on_pick)
self.ui.on_place.registerCallback(self._on_place)
self.ui.on_open.registerCallback(self._on_open)
self.ui.on_close.registerCallback(self._on_close)

# HACK: Work around GuiController input.
# TODO: Communicate to the controller via action hints.
gui_agent_controller._gui_input = self.gui_input
Expand Down Expand Up @@ -243,6 +270,11 @@ def draw_pip_viewport(self, pip_user_data: UserData):
destination_mask=Mask.from_index(self.user_index),
)

def pop_ui_events(self) -> List[Dict[str, Any]]:
events = list(self.ui_events)
self.ui_events.clear()
return events

def _get_camera_lookat_pos(self) -> mn.Vector3:
agent_root = get_agent_art_obj_transform(
self.app_service.sim,
Expand All @@ -255,6 +287,43 @@ def _get_camera_lookat_pos(self) -> mn.Vector3:
def _is_user_idle_this_frame(self) -> bool:
return not self.gui_input.get_any_input()

def _on_pick(self, e: UI.PickEventData):
self.ui_events.append(
{
"type": "pick",
"obj_handle": e.object_handle,
"obj_id": e.object_id,
}
)

def _on_place(self, e: UI.PlaceEventData):
self.ui_events.append(
{
"type": "place",
"obj_handle": e.object_handle,
"obj_id": e.object_id,
"receptacle_id": e.receptacle_id,
}
)

def _on_open(self, e: UI.OpenEventData):
self.ui_events.append(
{
"type": "open",
"obj_handle": e.object_handle,
"obj_id": e.object_id,
}
)

def _on_close(self, e: UI.CloseEventData):
self.ui_events.append(
{
"type": "close",
"obj_handle": e.object_handle,
"obj_id": e.object_id,
}
)


class AppStateRearrangeV2(AppStateBase):
"""
Expand Down Expand Up @@ -295,6 +364,10 @@ def __init__(
)
)

self._frame_recorder = FrameRecorder(
app_service, app_data, self._world
)

# Reset the environment immediately.
self.on_environment_reset(None)

Expand Down Expand Up @@ -322,9 +395,24 @@ def on_enter(self):
user_index
].gui_agent_controller._agent_idx

episode = self._app_service.episode_helper.current_episode
self._session.session_recorder.start_episode(
episode.episode_id,
episode.scene_id,
episode.scene_dataset_config,
user_index_to_agent_index_map,
)

def on_exit(self):
super().on_exit()

episode_success = all(
self._user_data[user_index].episode_success
for user_index in range(self._app_data.max_user_count)
)

self._session.session_recorder.end_episode(episode_success)

def _is_episode_finished(self) -> bool:
"""
Determines whether all users have finished their tasks.
Expand Down Expand Up @@ -501,9 +589,12 @@ def sim_update(self, dt: float, post_sim_update_dict):

# Collect data.
self._elapsed_time += dt
# TODO: Always record with non-human agent.
if self._is_any_user_active():
# TODO: Add data collection.
pass
frame_data = self._frame_recorder.record_state(
self._elapsed_time, self._user_data
)
self._session.session_recorder.record_frame(frame_data)

def _is_any_user_active(self) -> bool:
return any(
Expand Down
11 changes: 7 additions & 4 deletions examples/hitl/rearrange_v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from typing import Any, Dict, List

from session_recorder import SessionRecorder

from habitat_hitl.core.types import ConnectionRecord


Expand All @@ -25,8 +27,9 @@ def __init__(
self.episode_ids = episode_ids
self.current_episode_index = 0
self.connection_records = connection_records
self.error = "" # Use this to display error that causes termination

# Use the port as a discriminator for when there are multiple concurrent servers.
output_folder_suffix = str(config.habitat_hitl.networking.port)
self.output_folder = f"output_{output_folder_suffix}"
self.session_recorder = SessionRecorder(
config, connection_records, episode_ids
)

self.error = "" # Use this to display error that causes termination
85 changes: 85 additions & 0 deletions examples/hitl/rearrange_v2/session_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

from util import timestamp

from habitat_hitl.core.types import ConnectionRecord


class SessionRecorder:
def __init__(
self,
config: Dict[str, Any],
connection_records: Dict[int, ConnectionRecord],
episode_ids: List[str],
):
self.data = {
"episode_ids": episode_ids,
"completed": False,
"error": "",
"start_timestamp": timestamp(),
"end_timestamp": timestamp(),
"config": config,
"frame_count": 0,
"users": [],
"episodes": [],
}

for user_index in range(len(connection_records)):
self.data["users"].append(
{
"user_index": user_index,
"connection_record": connection_records[user_index],
}
)

def end_session(self, error: str):
self.data["end_timestamp"] = timestamp()
self.data["completed"] = True
self.data["error"] = error

def start_episode(
self,
episode_id: str,
scene_id: str,
dataset: str,
user_index_to_agent_index_map: Dict[int, int],
):
self.data["episodes"].append(
{
"episode_id": episode_id,
"scene_id": scene_id,
"start_timestamp": timestamp(),
"end_timestamp": timestamp(),
"completed": False,
"success": False,
"frame_count": 0,
"dataset": dataset,
"user_index_to_agent_index_map": user_index_to_agent_index_map,
"frames": [],
}
)

def end_episode(
self,
success: bool,
):
self.data["episodes"][-1]["end_timestamp"] = timestamp()
self.data["episodes"][-1]["success"] = success
self.data["episodes"][-1]["completed"] = True

def record_frame(
self,
frame_data: Dict[str, Any],
):
self.data["end_timestamp"] = timestamp()
self.data["frame_count"] += 1

self.data["episodes"][-1]["end_timestamp"] = timestamp()
self.data["episodes"][-1]["frame_count"] += 1
self.data["episodes"][-1]["frames"].append(frame_data)
Loading

0 comments on commit 409d0c3

Please sign in to comment.