"""
Utility functions. Some are adapted from other repositories.
@author: Nanbo Li
"""
import os
import json
import pickle
import h5py
import random
import numpy as np
import torch
from itertools import repeat
from collections import OrderedDict
# ------ General utilities ------
def ensure_dir(dirname):
    if not os.path.isdir(dirname):
        os.makedirs(dirname, exist_ok=True)


def read_json(fname):
    with open(fname, "r") as read_file:
        return json.load(read_file)


def write_json(content, fname):
    with open(fname, "w") as write_file:
        json.dump(content, write_file)


def read_h5py(fname):
    # Note: returns an open h5py.File handle; the caller is responsible for closing it.
    return h5py.File(fname, 'r')


def write_pick(content, fname):
    with open(fname, "wb") as pickle_out:
        pickle.dump(content, pickle_out)


def load_pickle(fname):
    # Use a context manager so the file handle is closed even if unpickling fails.
    with open(fname, "rb") as pickle_in:
        return pickle.load(pickle_in)


# ------ Training utilities ------
def inf_loop(data_loader):
    """Wrap a data loader so it can be iterated endlessly across epochs."""
    for loader in repeat(data_loader):
        yield from loader
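
# Usage sketch (illustrative only; `dataset` and `max_steps` are hypothetical):
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   for step, batch in enumerate(inf_loop(loader)):
#       if step >= max_steps:
#           break  # stop by iteration count; the loader restarts between epochs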

def set_random_seed(seed):
    """Seed Python, NumPy, and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_trained_mp(ckpt_path):
    """Load a checkpoint's model weights, stripping any 'module.' prefix left
    by (Distributed)DataParallel so the state dict fits an unwrapped model."""
    state_dict = torch.load(ckpt_path, map_location='cpu')['model']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == 'module.':
            name = k[7:]  # remove the `module.` prefix
        else:
            name = k
        new_state_dict[name] = v
    return new_state_dict
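
# Usage sketch (`MyModel` and the checkpoint path are hypothetical). Useful when
# a checkpoint was saved from a (Distributed)DataParallel wrapper but must be
# loaded into an unwrapped model:
#   model = MyModel()
#   model.load_state_dict(load_trained_mp('checkpoints/model_best.pth'))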

# ------ numpy/pytorch utilities ------
def numpify(tensor):
    return tensor.detach().cpu().numpy()


def iou_pair(outputs: torch.Tensor, labels: torch.Tensor):
    assert outputs.dim() == 2, 'expected [H, W], got {}'.format(outputs.size())
    assert labels.dim() == 2, 'expected [H, W], got {}'.format(labels.size())
    assert outputs.dtype == torch.float32
    assert labels.dtype == torch.float32
    intersection = (outputs > 0.9) & (labels > 0.9)
    union = (outputs > 0.9) | (labels > 0.9)
    # The epsilon keeps the ratio finite when both masks are empty.
    iou = (intersection.float().sum() + 1e-6) / (union.float().sum() + 1e-6)
    # Zero out near-zero scores while keeping the result a tensor.
    iou = iou if iou > 0.05 else iou * 0.0
    return iou
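
# Example (a minimal sketch): IoU of two shifted 2x2 squares in 4x4 float masks;
# intersection = 1 px, union = 7 px, so iou_pair returns ~0.143 (> 0.05, kept):
#   a = torch.zeros(4, 4); a[0:2, 0:2] = 1.0
#   b = torch.zeros(4, 4); b[1:3, 1:3] = 1.0
#   iou_pair(a, b)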

def iou_scn(x, gt, threshold=1.0):
    """Mean IoU between per-slot masks, compared slot-by-slot.

    :param x: [K, H, W] predicted masks
    :param gt: [K, H, W] ground-truth masks
    :param threshold: if < 1.0, return the fraction of slots whose IoU clears it
    """
    assert x.dim() == 3
    assert gt.dim() == 3
    len_x = x.size(0)
    len_gt = gt.size(0)
    assert len_x == len_gt
    M = []
    for k in range(len_x):
        M.append(iou_pair(x[k], gt[k]))
    iou_scores = torch.stack(M, dim=0)
    if threshold < 1.0:
        return (iou_scores >= threshold).type(x.dtype).mean()
    else:
        return iou_scores.mean()


def iou_scn_unmatch(x, gt, threshold=1.0):
    """Mean IoU between per-slot masks when slots are not yet aligned.

    :param x: [K, H, W] predicted masks
    :param gt: [K1, H, W] ground-truth masks (K1 may differ from K)
    :param threshold: if < 1.0, return the fraction of matches whose IoU clears it
    """
    assert x.dim() == 3
    assert gt.dim() == 3
    len_x = x.size(0)
    len_gt = gt.size(0)
    M = []
    for g in range(len_gt):
        assign = []
        for k in range(len_x):
            # Use pairwise IoU as the objective to match predictions to GT.
            assign.append(iou_pair(x[k], gt[g]))
        M.append(torch.tensor(assign))
    M = torch.stack(M, dim=0)
    # Greedy assignment: each GT mask takes its best-scoring prediction.
    iou_assign, m_indices = torch.max(M, dim=1)
    if threshold < 1.0:
        return (iou_assign >= threshold).type(x.dtype).mean(), m_indices.tolist()
    else:
        return iou_assign.mean(), m_indices.tolist()
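
# Example (a minimal sketch with hypothetical tensors): the matched variant
# compares slot k to slot k, while the unmatched variant first assigns each
# GT mask to its best-scoring prediction:
#   pred = torch.rand(5, 64, 64)                 # [K, H, W] soft masks
#   gt = (torch.rand(3, 64, 64) > 0.5).float()   # [K1, H, W] binary masks
#   mean_iou, gt_to_pred = iou_scn_unmatch(pred, gt)
#   matched_iou = iou_scn(pred[gt_to_pred], gt)  # reuse the assignment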

def match_or_compute_segmentation_iou(m_preds, m_gts, num_comps, match_list=None, threshold=1.0):
    """If `match_list` is provided, return the mIoU scores under that matching.
    Otherwise greedily match each GT mask to its best prediction (via
    `iou_scn_unmatch`) and also return the resulting match list.

    :param m_preds: [B, V, K, H, W] tensor of predicted masks
    :param m_gts: [B, V, K1, H, W] tensor of ground-truth masks
    :param num_comps: [B, V] number of objects per scene (given as GT)
    :param match_list: optional per-scene GT-to-prediction index lists from a previous call
    :param threshold: passed through to the per-scene IoU functions
    """
    assert m_preds.dim() == m_gts.dim()
    K = m_preds.size(2)
    B, V, K1, H, W = m_gts.size()
    assert K >= K1, 'K is smaller than the number of objects; segmentation evaluation should be disabled'
    assert num_comps.shape[0] == B
    num_comps = num_comps.reshape((B * V,)).tolist()
    # Binarise predictions: one-hot along the slot dimension at the per-pixel argmax.
    m_preds = torch.zeros_like(m_preds).scatter(2, torch.max(m_preds, dim=2, keepdim=True)[1], 1)
    m_preds = m_preds.reshape(B * V, K, H, W)
    m_gts = m_gts.reshape(B * V, K1, H, W).type(m_preds.dtype)
    N = m_preds.size(0)
    assert m_gts.size(0) == N
    iou_list = []
    if match_list is None:
        match_list = []
        for i in range(N):
            miou_scn_val, gt_to_out = iou_scn_unmatch(m_preds[i], m_gts[i, :num_comps[i]], threshold=threshold)
            iou_list.append(numpify(miou_scn_val))
            match_list.append(gt_to_out)
        return iou_list, match_list
    else:
        for i in range(N):
            miou_scn_val = iou_scn(m_preds[i, match_list[i]], m_gts[i, :num_comps[i]], threshold=threshold)
            iou_list.append(numpify(miou_scn_val))
        return iou_list, None
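
# Usage sketch (shapes are hypothetical): a first call without `match_list`
# computes the matching and per-scene mIoU; a second call can reuse that
# matching, e.g. for a thresholded segmentation-accuracy variant:
#   m_preds = torch.rand(2, 3, 6, 64, 64)                 # [B, V, K, H, W]
#   m_gts = (torch.rand(2, 3, 4, 64, 64) > 0.5).float()   # [B, V, K1, H, W]
#   num_comps = torch.full((2, 3), 4, dtype=torch.long)
#   iou, match = match_or_compute_segmentation_iou(m_preds, m_gts, num_comps)
#   acc, _ = match_or_compute_segmentation_iou(m_preds, m_gts, num_comps,
#                                              match_list=match, threshold=0.5)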

###############################
# Disentanglement related
###############################
shape_code = {
    "_background_": [0, 0, 0, 0, 0, 0, 1],
    "cube": [0, 0, 0, 0, 0, 1, 0],
    "owl": [0, 0, 0, 0, 1, 0, 0],
    "duck": [0, 0, 0, 1, 0, 0, 0],
    "mug": [0, 0, 1, 0, 0, 0, 0],
    "horse": [0, 1, 0, 0, 0, 0, 0],
    "teapot": [1, 0, 0, 0, 0, 0, 0]
}
color_code = {
    "gray": [87, 87, 87],
    "red": [173, 35, 35],
    "blue": [42, 75, 215],
    "green": [29, 105, 20],
    "brown": [129, 74, 25],
    "purple": [129, 38, 192],
    "cyan": [41, 208, 208],
    "yellow": [255, 238, 51]
}
mat_code = {
    'rubber': [0, 1],
    'metal': [1, 0],
}
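
# Example (a minimal sketch; the concatenation order is illustrative, not
# prescribed by the module): building a per-object attribute vector from the
# code tables above:
#   obj = {'shape': 'mug', 'color': 'red', 'material': 'rubber'}
#   vec = (shape_code[obj['shape']]
#          + [c / 255.0 for c in color_code[obj['color']]]
#          + mat_code[obj['material']])   # 7 + 3 + 2 = 12 dims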

def save_latents_for_eval(z_v_out, z_out, scn_indices, qry_views, gt_scenes_meta,
                          out_dir=None,
                          save_count=0):
    assert out_dir is not None, 'an output directory is required to save the latents'
    ensure_dir(out_dir)
    base_dir_name = os.path.basename(out_dir)
    # Stacking is idempotent, so hoist it out of the per-scene loop.
    z_v_out = np.stack(z_v_out, axis=0)
    GTs = []
    for s_count, sid in enumerate(scn_indices):
        # Correct the objects' permutations using the per-view GT depth orders.
        obj_to_gt_idx = []
        for v in qry_views:
            obj_to_gt_idx.append(gt_scenes_meta[sid]['depth_orders'][v])
        actual_idx = save_count + s_count
        save_to = os.path.join(out_dir, 'zout_{:06d}'.format(actual_idx))
        np.savez(save_to,
                 out_z_v=z_v_out,
                 out_z=z_out[s_count])
        s = {
            'z_path': '{}/{}'.format(base_dir_name, os.path.basename(save_to)),
            'obj_to_gt_map': obj_to_gt_idx,
            'objects': gt_scenes_meta[sid]['objects'],
            'query_views': qry_views
        }
        GTs.append(s)
    return GTs
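
# Usage sketch (all inputs hypothetical; each `gt_scenes_meta` entry must carry
# the 'depth_orders' and 'objects' fields that this function reads):
#   GTs = save_latents_for_eval(z_v_out, z_out, scn_indices=[0, 1],
#                               qry_views=[0, 2], gt_scenes_meta=meta,
#                               out_dir='eval/latents')
#   write_json(GTs, 'eval/latents_meta.json')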