Skip to content

Commit b97632a

Browse files
committed
Add script for identifying robot at multiple gripper oppenings
1 parent 090e2c5 commit b97632a

File tree

1 file changed

+271
-0
lines changed

1 file changed

+271
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""Identifies the robot parameters at multiple gripper openings.
2+
3+
The script will identify the robot parameters at multiple gripper openings. It will
4+
load the joint data from the given path, process the joint data, and then identify the
5+
robot parameters.
6+
7+
8+
The expected input directory structure is as follows:
9+
10+
<joint_data_path>/gripper_position_<gripper_position>/
11+
run_<run_idx>/
12+
joint_positions.npy
13+
joint_torques.npy
14+
sample_times_s.npy
15+
16+
The identified parameters will be saved into
17+
<joint_data_path>/gripper_position_<gripper_position>/identified_robot_params.npy.
18+
"""
19+
20+
import argparse
21+
import logging
22+
23+
from pathlib import Path
24+
25+
import numpy as np
26+
27+
from robot_payload_id.data import (
28+
compute_base_param_mapping,
29+
extract_numeric_data_matrix_autodiff,
30+
)
31+
from robot_payload_id.environment import create_arm
32+
from robot_payload_id.optimization import solve_inertial_param_sdp
33+
from robot_payload_id.utils import (
34+
ArmPlantComponents,
35+
JointData,
36+
get_plant_joint_params,
37+
process_joint_data,
38+
)
39+
40+
41+
def main():
42+
parser = argparse.ArgumentParser()
43+
parser.add_argument(
44+
"--num_data_points",
45+
type=int,
46+
default=10000,
47+
help="Number of data points to use.",
48+
)
49+
parser.add_argument(
50+
"--joint_data_path",
51+
type=Path,
52+
help="See main file docstring.",
53+
)
54+
parser.add_argument(
55+
"--regularization_weight",
56+
type=float,
57+
default=1e-3,
58+
help="The regularization weight.",
59+
)
60+
parser.add_argument(
61+
"--pos_order",
62+
type=int,
63+
default=10,
64+
help="The order of the filter for the joint positions. Only used if "
65+
+ "`--process_joint_data` is set.",
66+
)
67+
parser.add_argument(
68+
"--pos_cutoff_freq_hz",
69+
type=float,
70+
default=60.0,
71+
help="The cutoff frequency of the filter for the joint positions. Only used if "
72+
+ "`--process_joint_data` is set.",
73+
)
74+
parser.add_argument(
75+
"--vel_order",
76+
type=int,
77+
default=10,
78+
help="The order of the filter for the joint velocities. Only used if "
79+
+ "`--process_joint_data` is set.",
80+
)
81+
parser.add_argument(
82+
"--vel_cutoff_freq_hz",
83+
type=float,
84+
default=10.0,
85+
help="The cutoff frequency of the filter for the joint velocities. Only used if "
86+
+ "`--process_joint_data` is set.",
87+
)
88+
parser.add_argument(
89+
"--acc_order",
90+
type=int,
91+
default=10,
92+
help="The order of the filter for the joint accelerations. Only used if "
93+
+ "`--process_joint_data` is set.",
94+
)
95+
parser.add_argument(
96+
"--acc_cutoff_freq_hz",
97+
type=float,
98+
default=30.0,
99+
help="The cutoff frequency of the filter for the joint accelerations. Only used "
100+
+ "if `--process_joint_data` is set.",
101+
)
102+
parser.add_argument(
103+
"--torque_order",
104+
type=int,
105+
default=10,
106+
help="The order of the filter for the joint torques. Only used if "
107+
+ "`--process_joint_data` is set.",
108+
)
109+
parser.add_argument(
110+
"--torque_cutoff_freq_hz",
111+
type=float,
112+
default=10.0,
113+
help="The cutoff frequency of the filter for the joint torques. Only used if "
114+
+ "`--process_joint_data` is set.",
115+
)
116+
parser.add_argument(
117+
"--log_level",
118+
type=str,
119+
default="INFO",
120+
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
121+
help="Log level.",
122+
)
123+
124+
args = parser.parse_args()
125+
regularization_weight = args.regularization_weight
126+
num_endpoints_to_remove = args.num_endpoints_to_remove
127+
compute_velocities = not args.not_compute_velocities
128+
filter_positions = not args.not_filter_positions
129+
pos_filter_order = args.pos_order
130+
pos_cutoff_freq_hz = args.pos_cutoff_freq_hz
131+
vel_filter_order = args.vel_order
132+
vel_cutoff_freq_hz = args.vel_cutoff_freq_hz
133+
acc_filter_order = args.acc_order
134+
acc_cutoff_freq_hz = args.acc_cutoff_freq_hz
135+
torque_filter_order = args.torque_order
136+
torque_cutoff_freq_hz = args.torque_cutoff_freq_hz
137+
138+
logging.basicConfig(level=args.log_level)
139+
140+
# Create arm
141+
num_joints = 7
142+
# NOTE: This model must not have a payload attached. Otherwise, the w0 term will be
143+
# wrong and include the payload inertia.
144+
model_path = "./models/iiwa.dmd.yaml"
145+
arm_components = create_arm(
146+
arm_file_path=model_path, num_joints=num_joints, time_step=0.0
147+
)
148+
arm_plant_components = ArmPlantComponents(
149+
plant=arm_components.plant,
150+
plant_context=arm_components.plant.CreateDefaultContext(),
151+
)
152+
153+
subdirs = list(args.joint_data_path.iterdir())
154+
for subdir in subdirs:
155+
logging.info(f"Processing {subdir}")
156+
157+
# Load all runs and average data over time.
158+
run_dirs = list(subdir.iterdir())
159+
joint_datas = []
160+
for run_dir in run_dirs:
161+
joint_data = JointData.load_from_disk_allow_missing(run_dir)
162+
joint_datas.append(joint_data)
163+
joint_data = JointData.average_joint_datas(joint_datas)
164+
165+
# Process joint data
166+
joint_data = process_joint_data(
167+
joint_data=joint_data,
168+
num_endpoints_to_remove=num_endpoints_to_remove,
169+
compute_velocities=compute_velocities,
170+
filter_positions=filter_positions,
171+
pos_filter_order=pos_filter_order,
172+
pos_cutoff_freq_hz=pos_cutoff_freq_hz,
173+
vel_filter_order=vel_filter_order,
174+
vel_cutoff_freq_hz=vel_cutoff_freq_hz,
175+
acc_filter_order=acc_filter_order,
176+
acc_cutoff_freq_hz=acc_cutoff_freq_hz,
177+
torque_filter_order=torque_filter_order,
178+
torque_cutoff_freq_hz=torque_cutoff_freq_hz,
179+
)
180+
181+
# Generate data matrix
182+
(
183+
W_data_raw,
184+
w0_data,
185+
_,
186+
) = extract_numeric_data_matrix_autodiff(
187+
plant_components=arm_plant_components,
188+
joint_data=joint_data,
189+
add_rotor_inertia=False,
190+
add_reflected_inertia=True,
191+
add_viscous_friction=True,
192+
add_dynamic_dry_friction=True,
193+
payload_only=False,
194+
)
195+
tau_data = joint_data.joint_torques.flatten()
196+
# Transform from affine `tau = W * params + w0` into linear `(tau - w0) = W * params`
197+
tau_data -= w0_data
198+
199+
# Compute the base parameter mapping
200+
num_random_points = 2000
201+
joint_data_random = JointData(
202+
joint_positions=np.random.rand(num_random_points, num_joints) - 0.5,
203+
joint_velocities=np.random.rand(num_random_points, num_joints) - 0.5,
204+
joint_accelerations=np.random.rand(num_random_points, num_joints) - 0.5,
205+
joint_torques=np.zeros((num_random_points, num_joints)),
206+
sample_times_s=np.zeros(num_random_points),
207+
)
208+
W_data_random, _, _ = extract_numeric_data_matrix_autodiff(
209+
plant_components=arm_plant_components,
210+
joint_data=joint_data_random,
211+
add_rotor_inertia=False,
212+
add_reflected_inertia=True,
213+
add_viscous_friction=True,
214+
add_dynamic_dry_friction=True,
215+
payload_only=False,
216+
)
217+
base_param_mapping = compute_base_param_mapping(W_data_random)
218+
logging.info(
219+
f"{base_param_mapping.shape[1]} out of {base_param_mapping.shape[0]} "
220+
+ "parameters are identifiable."
221+
)
222+
223+
# Remove structurally unidentifiable columns to prevent SolutionResult.kUnbounded
224+
W_data = np.empty((W_data_raw.shape[0], base_param_mapping.shape[1]))
225+
for i in range(args.num_data_points):
226+
W_data[i * num_joints : (i + 1) * num_joints, :] = (
227+
W_data_raw[i * num_joints : (i + 1) * num_joints, :]
228+
@ base_param_mapping
229+
)
230+
231+
# Construct initial parameter guess
232+
params_guess = get_plant_joint_params(
233+
arm_components.plant,
234+
arm_components.plant.CreateDefaultContext(),
235+
add_rotor_inertia=False,
236+
add_reflected_inertia=True,
237+
add_viscous_friction=True,
238+
add_dynamic_dry_friction=True,
239+
payload_only=False,
240+
)
241+
242+
(_, result, variable_names, variable_vec, _) = solve_inertial_param_sdp(
243+
num_links=num_joints,
244+
W_data=W_data,
245+
tau_data=tau_data,
246+
base_param_mapping=base_param_mapping,
247+
regularization_weight=regularization_weight,
248+
params_guess=params_guess,
249+
use_euclidean_regularization=False,
250+
add_rotor_inertia=False,
251+
add_reflected_inertia=True,
252+
add_viscous_friction=True,
253+
add_dynamic_dry_friction=True,
254+
payload_only=False,
255+
)
256+
if result.is_success():
257+
final_cost = result.get_optimal_cost()
258+
logging.info(f"SDP cost: {final_cost}")
259+
var_sol_dict = dict(zip(variable_names, result.GetSolution(variable_vec)))
260+
261+
out_path = subdir / "identified_robot_params.npy"
262+
logging.info(f"Saving parameters to {out_path}")
263+
np.save(out_path, var_sol_dict)
264+
else:
265+
logging.warning("Failed to solve inertial parameter SDP!")
266+
logging.info(f"Solution result:\n{result.get_solution_result()}")
267+
logging.info(f"Solver details:\n{result.get_solver_details()}")
268+
269+
270+
if __name__ == "__main__":
271+
main()

0 commit comments

Comments
 (0)