-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbuild.py
116 lines (82 loc) · 3.42 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import logging
import sys
import zipfile
from pathlib import Path
from typing import Callable, List
from yupi import Trajectory
from yupi.core import JSONSerializer
import config
from utils.utils import _get_path
# Directory that main() scans for recipe modules: every ``*.py`` file whose
# name does not start with an underscore is treated as a dataset recipe.
RECIPIES_DIR = Path("./recipes")
def _read_cache_version(name: str) -> int:
    """Return the version stored in the cached ``yupi_data.json`` of a dataset.

    Args:
        name: Dataset name, used to locate its directory under
            ``config.DS_DIR``.

    Returns:
        The cached ``"version"`` value, or ``-1`` when there is no usable
        cache: the file is missing, is not valid JSON (e.g. left truncated by
        an interrupted build), does not hold a JSON object, or has no
        ``"version"`` key. ``-1`` makes any real version look newer, so the
        dataset gets rebuilt.
    """
    ds_dir = _get_path(config.DS_DIR, name)
    yupi_data_json = ds_dir / "yupi_data.json"
    if not yupi_data_json.exists():
        return -1

    try:
        yupi_data = _load_yupi_data(yupi_data_json)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        # A broken cache must trigger a rebuild, not abort the whole run.
        logging.warning(
            "Ignoring unreadable cache for dataset '%s' (%s); it will be rebuilt",
            name,
            err,
        )
        return -1

    if not isinstance(yupi_data, dict):
        # Calling .get() on a list/str would raise AttributeError, which the
        # caller would misreport as a recipe with missing fields.
        logging.warning(
            "Ignoring malformed cache for dataset '%s'; it will be rebuilt", name
        )
        return -1

    return yupi_data.get("version", -1)
def _cache_up_to_date(name: str, version: int) -> bool:
    """Tell whether the cached dataset ``name`` is at least at ``version``.

    Returns ``False`` (a rebuild is required) only when the requested
    ``version`` is strictly greater than the cached one.
    """
    cached_version = _read_cache_version(name)
    if version > cached_version:
        return False
    return True
def _load_yupi_data(yupi_data_json: Path):
    """Parse the UTF-8 encoded JSON file at ``yupi_data_json`` and return its content."""
    handle = open(yupi_data_json, "r", encoding="utf-8")
    with handle:
        parsed = json.load(handle)
    return parsed
def _save_yupi_data(yupi_data: dict, path: Path):
    """Serialize ``yupi_data`` as indented UTF-8 JSON into ``path``.

    The data is first written to a temporary sibling file and then moved over
    ``path`` in a single step. If serialization fails or the process is
    interrupted, ``path`` keeps its previous content instead of being left
    truncated, which would otherwise break the next cache-version check.

    Args:
        yupi_data: JSON-serializable mapping to store.
        path: Destination file; its parent directory must already exist.

    Raises:
        TypeError: If ``yupi_data`` contains values ``json`` cannot serialize.
        OSError: If the file cannot be written or replaced.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as md_file:
            json.dump(yupi_data, md_file, ensure_ascii=False, indent=4)
        # Atomic on POSIX and Windows: readers see the old or the new file.
        tmp_path.replace(path)
    finally:
        # Only still present when something above failed.
        if tmp_path.exists():
            tmp_path.unlink()
def _update_labels(trajs: List[Trajectory]):
    """Set each trajectory's ``traj_id`` to its position in ``trajs`` (as a string).

    The trajectories are modified in place and the same list is returned.
    """
    for position, trajectory in enumerate(trajs):
        new_id = str(position)
        trajectory.traj_id = new_id
    return trajs
def _build_recipe(output_dir: Path, name: str, version: int, build_func: Callable):
    """Build a dataset, store it in the cache directory and zip it to ``output_dir``.

    ``build_func`` is called without arguments and must return a
    ``(trajectories, labels)`` pair. The trajectories are re-identified by
    position, serialized together with ``version`` and ``labels`` into
    ``yupi_data.json`` inside the dataset directory, and that file is then
    archived as ``<name>.json`` inside ``<output_dir>/<name>.zip``.

    Raises:
        ValueError: If the number of trajectories and labels differ, or if
            any trajectory has fewer than 2 points.
    """
    trajs, labels = build_func()

    # Validate the recipe output before touching the file system.
    n_trajs, n_labels = len(trajs), len(labels)
    if n_trajs != n_labels:
        raise ValueError(
            "Number of trajectories and labels must be equal. "
            f"Got {n_trajs} trajectories and {n_labels} labels."
        )
    if any(len(traj) <= 1 for traj in trajs):
        raise ValueError("All trajectories must have at least 2 points.")

    trajs = _update_labels(trajs)

    cache_dir = _get_path(config.DS_DIR, name)
    cache_dir.mkdir(parents=True, exist_ok=True)

    payload = {
        "version": version,
        "trajs": [JSONSerializer.to_json(traj) for traj in trajs],
        "labels": labels,
    }

    logging.info("Saving yupify trajectories for %s dataset", name)
    cache_file = cache_dir / "yupi_data.json"
    _save_yupi_data(payload, cache_file)

    # Publish the cached JSON as a compressed archive in the output directory.
    archive_path = output_dir / f"{name}.zip"
    archive_path.parent.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
        archive.write(filename=cache_file, arcname=f"{name}.json")
def build_recipe(output_dir: Path, name: str, version: int, build_func: Callable):
    """Build dataset ``name`` unless its cache is already at ``version`` or newer.

    When the cache is current, only an informational message is logged.
    """
    if not _cache_up_to_date(name, version):
        _build_recipe(output_dir, name, version, build_func)
        return

    logging.info("Dataset '%s' is up to date (v%s)", name, version)
def process_recipe(output_dir: Path, recipe_py_path: Path):
    """Import a recipe module and build the dataset it describes.

    The module is imported as ``recipes.<stem>`` and must expose ``NAME``,
    ``VERSION`` and a ``build`` callable.

    Args:
        output_dir: Directory where the zipped dataset is written.
        recipe_py_path: Path of the recipe's ``.py`` file.

    Raises:
        AttributeError: If the module lacks ``NAME``, ``VERSION`` or ``build``.
        ImportError: If the recipe module cannot be imported.
    """
    # Function-scope import so this block does not depend on the file header.
    import importlib

    # ``stem`` removes only the final suffix; str.replace(".py", "") would
    # also strip any other ".py" occurrence inside the file name.
    module_name = recipe_py_path.stem
    recipe = importlib.import_module(f"recipes.{module_name}")

    name = recipe.NAME
    version = recipe.VERSION
    build_func = recipe.build
    build_recipe(output_dir, name, version, build_func)
def main():
    """Build every recipe found in ``RECIPIES_DIR`` into ``./builds``.

    An optional first command-line argument holds a comma-separated list of
    recipe names (file stems); when given, only those recipes are processed.
    Without it, all recipes are processed. Recipe files whose name starts
    with an underscore are always skipped.
    """
    # ``None`` means "no filter"; it must not be used in a membership test.
    only_recipies = set(sys.argv[1].split(",")) if len(sys.argv) > 1 else None
    output_dir = Path("./builds")

    for dataset_recipe in RECIPIES_DIR.glob("[!_]*.py"):
        if only_recipies is not None and dataset_recipe.stem not in only_recipies:
            continue
        try:
            process_recipe(output_dir, dataset_recipe)
        except AttributeError:
            # Raised when the recipe module lacks NAME, VERSION or build;
            # log it and carry on with the remaining recipes.
            logging.error("Recipe '%s' has missing fields", str(dataset_recipe))
# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()