Skip to content

Commit 974028b

Browse files
Organize test folders (#856)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
1 parent a36ed39 commit 974028b

File tree

79 files changed

+63
-106
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+63
-106
lines changed

.dockerignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ pip-log.txt
7373
pip-delete-this-directory.txt
7474

7575
# Unit test / coverage reports
76-
!tests/data
76+
!tests/artifacts
7777
htmlcov/
7878
.tox/
7979
.nox/

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pip-log.txt
7878
pip-delete-this-directory.txt
7979

8080
# Unit test / coverage reports
81-
!tests/data
81+
!tests/artifacts
8282
htmlcov/
8383
.tox/
8484
.nox/

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
exclude: ^(tests/data)
15+
exclude: "tests/artifacts/.*\\.safetensors$"
1616
default_language_version:
1717
python: python3.10
1818
repos:

CONTRIBUTING.md

+1-1

lerobot/common/robot_devices/cameras/intelrealsense.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
4848
connected to the computer.
4949
"""
5050
if mock:
51-
import tests.mock_pyrealsense2 as rs
51+
import tests.cameras.mock_pyrealsense2 as rs
5252
else:
5353
import pyrealsense2 as rs
5454

@@ -100,7 +100,7 @@ def save_images_from_cameras(
100100
serial_numbers = [cam["serial_number"] for cam in camera_infos]
101101

102102
if mock:
103-
import tests.mock_cv2 as cv2
103+
import tests.cameras.mock_cv2 as cv2
104104
else:
105105
import cv2
106106

@@ -253,7 +253,7 @@ def __init__(
253253
self.logs = {}
254254

255255
if self.mock:
256-
import tests.mock_cv2 as cv2
256+
import tests.cameras.mock_cv2 as cv2
257257
else:
258258
import cv2
259259

@@ -287,7 +287,7 @@ def connect(self):
287287
)
288288

289289
if self.mock:
290-
import tests.mock_pyrealsense2 as rs
290+
import tests.cameras.mock_pyrealsense2 as rs
291291
else:
292292
import pyrealsense2 as rs
293293

@@ -375,7 +375,7 @@ def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndar
375375
)
376376

377377
if self.mock:
378-
import tests.mock_cv2 as cv2
378+
import tests.cameras.mock_cv2 as cv2
379379
else:
380380
import cv2
381381

lerobot/common/robot_devices/cameras/opencv.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _find_cameras(
8080
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
8181
) -> list[int | str]:
8282
if mock:
83-
import tests.mock_cv2 as cv2
83+
import tests.cameras.mock_cv2 as cv2
8484
else:
8585
import cv2
8686

@@ -269,7 +269,7 @@ def __init__(self, config: OpenCVCameraConfig):
269269
self.logs = {}
270270

271271
if self.mock:
272-
import tests.mock_cv2 as cv2
272+
import tests.cameras.mock_cv2 as cv2
273273
else:
274274
import cv2
275275

@@ -286,7 +286,7 @@ def connect(self):
286286
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
287287

288288
if self.mock:
289-
import tests.mock_cv2 as cv2
289+
import tests.cameras.mock_cv2 as cv2
290290
else:
291291
import cv2
292292

@@ -398,7 +398,7 @@ def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
398398
# so we convert the image color from BGR to RGB.
399399
if requested_color_mode == "rgb":
400400
if self.mock:
401-
import tests.mock_cv2 as cv2
401+
import tests.cameras.mock_cv2 as cv2
402402
else:
403403
import cv2
404404

lerobot/common/robot_devices/motors/dynamixel.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def connect(self):
332332
)
333333

334334
if self.mock:
335-
import tests.mock_dynamixel_sdk as dxl
335+
import tests.motors.mock_dynamixel_sdk as dxl
336336
else:
337337
import dynamixel_sdk as dxl
338338

@@ -356,7 +356,7 @@ def connect(self):
356356

357357
def reconnect(self):
358358
if self.mock:
359-
import tests.mock_dynamixel_sdk as dxl
359+
import tests.motors.mock_dynamixel_sdk as dxl
360360
else:
361361
import dynamixel_sdk as dxl
362362

@@ -646,7 +646,7 @@ def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] |
646646

647647
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
648648
if self.mock:
649-
import tests.mock_dynamixel_sdk as dxl
649+
import tests.motors.mock_dynamixel_sdk as dxl
650650
else:
651651
import dynamixel_sdk as dxl
652652

@@ -691,7 +691,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None):
691691
start_time = time.perf_counter()
692692

693693
if self.mock:
694-
import tests.mock_dynamixel_sdk as dxl
694+
import tests.motors.mock_dynamixel_sdk as dxl
695695
else:
696696
import dynamixel_sdk as dxl
697697

@@ -757,7 +757,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None):
757757

758758
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
759759
if self.mock:
760-
import tests.mock_dynamixel_sdk as dxl
760+
import tests.motors.mock_dynamixel_sdk as dxl
761761
else:
762762
import dynamixel_sdk as dxl
763763

@@ -793,7 +793,7 @@ def write(self, data_name, values: int | float | np.ndarray, motor_names: str |
793793
start_time = time.perf_counter()
794794

795795
if self.mock:
796-
import tests.mock_dynamixel_sdk as dxl
796+
import tests.motors.mock_dynamixel_sdk as dxl
797797
else:
798798
import dynamixel_sdk as dxl
799799

lerobot/common/robot_devices/motors/feetech.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def connect(self):
313313
)
314314

315315
if self.mock:
316-
import tests.mock_scservo_sdk as scs
316+
import tests.motors.mock_scservo_sdk as scs
317317
else:
318318
import scservo_sdk as scs
319319

@@ -337,7 +337,7 @@ def connect(self):
337337

338338
def reconnect(self):
339339
if self.mock:
340-
import tests.mock_scservo_sdk as scs
340+
import tests.motors.mock_scservo_sdk as scs
341341
else:
342342
import scservo_sdk as scs
343343

@@ -664,7 +664,7 @@ def avoid_rotation_reset(self, values, motor_names, data_name):
664664

665665
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
666666
if self.mock:
667-
import tests.mock_scservo_sdk as scs
667+
import tests.motors.mock_scservo_sdk as scs
668668
else:
669669
import scservo_sdk as scs
670670

@@ -702,7 +702,7 @@ def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_
702702

703703
def read(self, data_name, motor_names: str | list[str] | None = None):
704704
if self.mock:
705-
import tests.mock_scservo_sdk as scs
705+
import tests.motors.mock_scservo_sdk as scs
706706
else:
707707
import scservo_sdk as scs
708708

@@ -782,7 +782,7 @@ def read(self, data_name, motor_names: str | list[str] | None = None):
782782

783783
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
784784
if self.mock:
785-
import tests.mock_scservo_sdk as scs
785+
import tests.motors.mock_scservo_sdk as scs
786786
else:
787787
import scservo_sdk as scs
788788

@@ -818,7 +818,7 @@ def write(self, data_name, values: int | float | np.ndarray, motor_names: str |
818818
start_time = time.perf_counter()
819819

820820
if self.mock:
821-
import tests.mock_scservo_sdk as scs
821+
import tests.motors.mock_scservo_sdk as scs
822822
else:
823823
import scservo_sdk as scs
824824

pyproject.toml

+1-24
Original file line numberDiff line numberDiff line change
@@ -102,30 +102,7 @@ requires-poetry = ">=2.1"
102102
[tool.ruff]
103103
line-length = 110
104104
target-version = "py310"
105-
exclude = [
106-
"tests/data",
107-
".bzr",
108-
".direnv",
109-
".eggs",
110-
".git",
111-
".git-rewrite",
112-
".hg",
113-
".mypy_cache",
114-
".nox",
115-
".pants.d",
116-
".pytype",
117-
".ruff_cache",
118-
".svn",
119-
".tox",
120-
".venv",
121-
"__pypackages__",
122-
"_build",
123-
"buck-out",
124-
"build",
125-
"dist",
126-
"node_modules",
127-
"venv",
128-
]
105+
exclude = ["tests/artifacts/**/*.safetensors"]
129106

130107
[tool.ruff.lint]
131108
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]

tests/scripts/save_dataset_to_safetensors.py tests/artifacts/datasets/save_dataset_to_safetensors.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
2424
2525
Example usage:
26-
`python tests/scripts/save_dataset_to_safetensors.py`
26+
`python tests/artifacts/datasets/save_dataset_to_safetensors.py`
2727
"""
2828

2929
import shutil
@@ -88,4 +88,4 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
8888
"lerobot/nyu_franka_play_dataset",
8989
"lerobot/cmu_stretch",
9090
]:
91-
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
91+
save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)

tests/scripts/save_image_transforms_to_safetensors.py tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from lerobot.common.utils.random_utils import seeded_context
2929

30-
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
30+
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
3131
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
3232

3333

tests/scripts/save_policy_to_safetensors.py tests/artifacts/policies/save_policy_to_safetensors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,5 @@ def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: s
141141
raise RuntimeError("No policies were provided!")
142142
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
143143
ds_name = ds_repo_id.split("/")[-1]
144-
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
144+
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
145145
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
File renamed without changes.
File renamed without changes.

tests/test_cameras.py tests/cameras/test_cameras.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_camera(request, camera_type, mock):
146146
camera.connect()
147147

148148
if mock:
149-
import tests.mock_cv2 as cv2
149+
import tests.cameras.mock_cv2 as cv2
150150
else:
151151
import cv2
152152

File renamed without changes.

tests/test_datasets.py tests/datasets/test_datasets.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,12 @@ def test_flatten_unflatten_dict():
473473
)
474474
@require_x86_64_kernel
475475
def test_backward_compatibility(repo_id):
476-
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
476+
"""The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""
477477

478478
# TODO(rcadene, aliberts): remove dataset download
479479
dataset = LeRobotDataset(repo_id, episodes=[0])
480480

481-
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
481+
test_dir = Path("tests/artifacts/datasets") / repo_id
482482

483483
def load_and_compare(i):
484484
new_frame = dataset[i] # noqa: B023
File renamed without changes.

tests/test_image_transforms.py tests/datasets/test_image_transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
save_all_transforms,
3434
save_each_transform,
3535
)
36-
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
36+
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
3737
from tests.utils import require_x86_64_kernel
3838

3939

File renamed without changes.
File renamed without changes.
File renamed without changes.

tests/test_utils.py tests/datasets/test_utils.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#!/usr/bin/env python
2+
13
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
24
#
35
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,13 +13,32 @@
1113
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1214
# See the License for the specific language governing permissions and
1315
# limitations under the License.
16+
1417
import torch
1518
from datasets import Dataset
19+
from huggingface_hub import DatasetCard
1620

1721
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
18-
from lerobot.common.datasets.utils import (
19-
hf_transform_to_torch,
20-
)
22+
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
23+
24+
25+
def test_default_parameters():
26+
card = create_lerobot_dataset_card()
27+
assert isinstance(card, DatasetCard)
28+
assert card.data.tags == ["LeRobot"]
29+
assert card.data.task_categories == ["robotics"]
30+
assert card.data.configs == [
31+
{
32+
"config_name": "default",
33+
"data_files": "data/*/*.parquet",
34+
}
35+
]
36+
37+
38+
def test_with_tags():
39+
tags = ["tag1", "tag2"]
40+
card = create_lerobot_dataset_card(tags=tags)
41+
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
2142

2243

2344
def test_calculate_episode_data_index():
File renamed without changes.

tests/test_envs.py tests/envs/test_envs.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
import lerobot
2424
from lerobot.common.envs.factory import make_env, make_env_config
2525
from lerobot.common.envs.utils import preprocess_observation
26-
27-
from .utils import require_env
26+
from tests.utils import require_env
2827

2928
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
3029

File renamed without changes.

0 commit comments

Comments
 (0)