forked from luigifreda/pyslam
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_aslfeat.py
130 lines (101 loc) · 4.58 KB
/
feature_aslfeat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Copyright 2020, Zixin Luo, HKUST.
Image matching example.
"""
import config
config.cfg.set_lib('ASLFeat',prepend=True)
from threading import RLock
import warnings # to disable tensorflow-numpy warnings: from https://github.com/tensorflow/tensorflow/issues/30427
warnings.filterwarnings('ignore', category=FutureWarning)
import os
import cv2
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = ""
if False:
import tensorflow as tf
else:
# from https://stackoverflow.com/questions/56820327/the-name-tf-session-is-deprecated-please-use-tf-compat-v1-session-instead
import tensorflow.compat.v1 as tf
from ASLFeat.utils.opencvhelper import MatcherWrapper
from ASLFeat.models.feat_model import FeatModel
# from ASLFeat.models import get_model
from utils_tf import set_tf_logging
kVerbose = True  # when True, detectAndCompute() prints a per-frame summary


# interface for pySLAM
class ASLFeature2D:
    """pySLAM feature detector/descriptor wrapper around the ASLFeat network.

    Exposes OpenCV-like ``detect`` / ``compute`` / ``detectAndCompute`` methods.
    Model access is serialized through an ``RLock``. The results for the last
    processed frame object are cached, so ``detect`` followed by ``compute`` on
    the same frame runs the network only once.
    """

    def __init__(self,
                 num_features=2000,
                 model_type='ckpt',
                 do_tf_logging=False):
        """Load the pretrained ASLFeat model.

        Args:
            num_features: maximum number of keypoints per image (passed to the
                model config as ``kpt_n``).
            model_type: model format; only ``'ckpt'`` is supported.
            do_tf_logging: forwarded to ``set_tf_logging``.

        Raises:
            NotImplementedError: if ``model_type`` is not ``'ckpt'``.
        """
        print('Using ASLFeat')
        self.lock = RLock()
        self.model_base_path = config.cfg.root_folder + '/thirdparty/ASLFeat/'
        set_tf_logging(do_tf_logging)
        self.num_features = num_features
        self.model_type = model_type
        self.model_path = self.model_base_path + 'pretrained/aslfeatv2'
        # A frozen-graph ('pb' -> 'aslfeat.pb') variant is not wired up.
        if self.model_type != 'ckpt':
            # The problem here is an unsupported model_type, not a missing file,
            # so say that instead of "Model not found at path ...".
            msg = "Unsupported ASLFeat model_type {!r} (only 'ckpt' is supported); model folder: {}".format(
                self.model_type, self.model_path)
            print(msg)
            raise NotImplementedError(msg)
        self.model_path = os.path.join(self.model_path, 'model.ckpt-60000')
        # Representative size for visualization and for a possible conversion of
        # extracted points to cv2.KeyPoint.
        self.keypoint_size = 10
        # Cache of the last processed frame and its results (see detect/compute).
        self.kps = []
        self.des = []
        self.frame = None
        print('==> Loading pre-trained network.')
        config_loc = {'max_dim': 2048,
                      'config': {
                          'kpt_n': self.num_features,
                          'kpt_refinement': True,
                          'deform_desc': 1,
                          'score_thld': 0.5,
                          'edge_thld': 10,
                          'multi_scale': True,
                          'multi_level': True,
                          'nms_size': 3,
                          'eof_mask': 5,
                          'need_norm': True,
                          'use_peakiness': True}}
        self.model = FeatModel(self.model_path, **config_loc)
        print('==> Successfully loaded pre-trained network.')

    def __del__(self):
        """Close the underlying model, tolerating a partially built instance."""
        # __init__ may raise (e.g. unsupported model_type) before self.model, or
        # even self.lock, exists; a finalizer must not raise AttributeError then.
        model = getattr(self, 'model', None)
        if model is None:
            return
        lock = getattr(self, 'lock', None)
        if lock is None:
            model.close()
            return
        with lock:
            model.close()

    def prep_img(self, img):
        """Build the network inputs for one image.

        Args:
            img: BGR colour image (H x W x 3), or a grayscale image
                (H x W or H x W x 1).

        Returns:
            ``(rgb_list, gray_list)``: single-element lists holding the RGB
            image and the H x W x 1 grayscale image.
        """
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            # cv2.COLOR_BGR2GRAY rejects single-channel input, and img[..., ::-1]
            # would flip the columns of a 2-D array, so handle grayscale here.
            gray = img.reshape(img.shape[0], img.shape[1], 1)
            rgb = np.repeat(gray, 3, axis=2)
        else:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., np.newaxis]
            rgb = img[..., ::-1]  # BGR -> RGB via a reversed-channel view
        return [rgb], [gray]

    def compute_kps_des(self, frame):
        """Run ASLFeat on ``frame`` and return ``(kps, des)``."""
        with self.lock:
            _, gray_list = self.prep_img(frame)
            # extract features; the third value returned by the model is unused.
            des, kps, _ = self.model.run_test_data(gray_list[0])
            # NOTE(review): kps are returned exactly as run_test_data produced
            # them; no conversion to cv2.KeyPoint happens here despite
            # self.keypoint_size -- confirm callers expect this format.
            return kps, des

    def detectAndCompute(self, frame, mask=None):  # mask is a fake input
        """Extract keypoints and descriptors from ``frame`` and cache them.

        Returns:
            ``(kps, des)`` for ``frame``.
        """
        with self.lock:
            self.frame = frame
            self.kps, self.des = self.compute_kps_des(frame)
            if kVerbose:
                print('detector: ASLFeat, descriptor: ASLFeat, #features: ', len(self.kps), ', frame res: ', frame.shape[0:2])
            return self.kps, self.des

    def detect(self, frame, mask=None):  # mask is a fake input
        """Return keypoints for ``frame``, reusing the cache if it is the last frame."""
        with self.lock:
            # Identity check: only the exact same frame object hits the cache.
            if self.frame is not frame:
                self.detectAndCompute(frame)
            return self.kps

    def compute(self, frame, kps=None, mask=None):  # kps is a fake input, mask is a fake input
        """Return ``(kps, des)`` for ``frame``; the ``kps`` argument is ignored.

        Reuses the cached results when ``frame`` is the last processed frame
        object. Otherwise both keypoints and descriptors are recomputed.
        """
        with self.lock:
            if self.frame is not frame:
                self.detectAndCompute(frame)
            return self.kps, self.des