Skip to content

Commit

Permalink
XFeat parser for Mono and Stereo. (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeroo authored Sep 30, 2024
1 parent 6769aef commit de544a4
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 78 deletions.
6 changes: 4 additions & 2 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from .scrfd import SCRFDParser
from .segmentation import SegmentationParser
from .superanimal_landmarker import SuperAnimalParser
from .xfeat import XFeatParser
from .xfeat_mono import XFeatMonoParser
from .xfeat_stereo import XFeatStereoParser
from .yolo import YOLOExtendedParser
from .yunet import YuNetParser

Expand All @@ -28,7 +29,8 @@
"SuperAnimalParser",
"KeypointParser",
"MLSDParser",
"XFeatParser",
"XFeatMonoParser",
"XFeatStereoParser",
"ClassificationParser",
"YOLOExtendedParser",
"FastSAMParser",
Expand Down
133 changes: 61 additions & 72 deletions depthai_nodes/ml/parsers/utils/xfeat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np


Expand Down Expand Up @@ -59,26 +58,44 @@ def local_maximum_filter(x: np.ndarray, kernel_size: int) -> np.ndarray:
return local_max


def bilinear_grid_sample(
im: np.ndarray, grid: np.ndarray, align_corners: bool = False
) -> np.ndarray:
"""Bilinear grid sample.
def normgrid(x: np.ndarray, H: float, W: float) -> np.ndarray:
    """Normalize pixel coordinates to the range [-1, 1].

    The last axis of ``x`` holds ``(x, y)`` pairs. The x component is divided by
    ``W - 1`` and the y component by ``H - 1``, then the result is mapped linearly
    from [0, 1] to [-1, 1].

    @param x: Input coordinates, shape (N, Hg, Wg, 2). Any array-like whose last
        axis has length 2 is accepted.
    @type x: np.ndarray
    @param H: Height of the output feature map
    @type H: float
    @param W: Width of the output feature map
    @type W: float
    @return: Normalized coordinates, same shape as the input. Floating point inputs
        keep their dtype; integer inputs produce float64.
    @rtype: np.ndarray
    """
    x = np.asarray(x)
    # Build the divisor in a floating dtype. Casting it to an integer coordinate
    # dtype would silently truncate fractional sizes (e.g. W - 1 == 2.5 -> 2).
    # Floating inputs keep their own dtype so the result precision is unchanged.
    if np.issubdtype(x.dtype, np.floating):
        divisor_dtype = x.dtype
    else:
        divisor_dtype = np.float64
    divisor = np.array([W - 1, H - 1], dtype=divisor_dtype)
    return 2.0 * (x / divisor) - 1.0

@param im: Input image tensor.

def bilinear(im, pos, H, W):
"""Given an input and a flow-field grid, computes the output using input values and
pixel locations from grid. Supported only bilinear interpolation method to sample
the input pixels.
@param im: Input feature map, shape (N, C, H, W)
@type im: np.ndarray
@param grid: Grid tensor.
@type grid: np.ndarray
@param align_corners: Whether to align corners.
@type align_corners: bool
@return: Output image tensor after applying bilinear grid sample.
@param pos: Point coordinates, shape (N, Hg, Wg, 2)
@type pos: np.ndarray
@param H: Height of the output feature map
@type H: int
@param W: Width of the output feature map
@type W: int
@return: A tensor with sampled points, shape (N, C, Hg, Wg)
@rtype: np.ndarray
"""
align_corners = False
n, c, h, w = im.shape
gn, gh, gw, _ = grid.shape
assert n == gn
grid = normgrid(pos, H, W)[..., np.newaxis]
grid = grid.transpose(0, 1, 3, 2)

x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
x = grid[..., 0]
y = grid[..., 1]

if align_corners:
x = ((x + 1) / 2) * (w - 1)
Expand All @@ -95,43 +112,44 @@ def bilinear_grid_sample(
x1 = x0 + 1
y1 = y0 + 1

wa = ((x1 - x) * (y1 - y)).reshape(n, 1, -1)
wb = ((x1 - x) * (y - y0)).reshape(n, 1, -1)
wc = ((x - x0) * (y1 - y)).reshape(n, 1, -1)
wd = ((x - x0) * (y - y0)).reshape(n, 1, -1)
wa = ((x1 - x) * (y1 - y))[:, np.newaxis]
wb = ((x1 - x) * (y - y0))[:, np.newaxis]
wc = ((x - x0) * (y1 - y))[:, np.newaxis]
wd = ((x - x0) * (y - y0))[:, np.newaxis]

# Apply padding
im_padded = np.pad(
im, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="constant", constant_values=0
)
padded_h = h + 2
padded_w = w + 2
x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

# Clip coordinates to padded image size
# Adjust points for padding
x0, x1 = x0 + 1, x1 + 1
y0, y1 = y0 + 1, y1 + 1

# Clip coordinates to stay within bounds
x0 = np.clip(x0, 0, padded_w - 1)
x1 = np.clip(x1, 0, padded_w - 1)
y0 = np.clip(y0, 0, padded_h - 1)
y1 = np.clip(y1, 0, padded_h - 1)

# Flatten im_padded for indexing
im_padded = im_padded.reshape(n, c, -1)

x0_y0 = (x0 + y0 * padded_w).reshape(n, 1, -1)
x0_y1 = (x0 + y1 * padded_w).reshape(n, 1, -1)
x1_y0 = (x1 + y0 * padded_w).reshape(n, 1, -1)
x1_y1 = (x1 + y1 * padded_w).reshape(n, 1, -1)
x0_y0 = (x0 + y0 * padded_w)[:, np.newaxis].repeat(c, axis=1)
x0_y1 = (x0 + y1 * padded_w)[:, np.newaxis].repeat(c, axis=1)
x1_y0 = (x1 + y0 * padded_w)[:, np.newaxis].repeat(c, axis=1)
x1_y1 = (x1 + y1 * padded_w)[:, np.newaxis].repeat(c, axis=1)

def gather(im_padded, idx):
idx = idx.astype(np.int32)
gathered = np.take_along_axis(im_padded, idx, axis=2)
return gathered
Ia = np.take_along_axis(im_padded, x0_y0, axis=2)
Ib = np.take_along_axis(im_padded, x0_y1, axis=2)
Ic = np.take_along_axis(im_padded, x1_y0, axis=2)
Id = np.take_along_axis(im_padded, x1_y1, axis=2)

Ia = gather(im_padded, x0_y0)
Ib = gather(im_padded, x0_y1)
Ic = gather(im_padded, x1_y0)
Id = gather(im_padded, x1_y1)

result = (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
result = (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(
n, c, grid.shape[1], grid.shape[2]
)
return result


Expand Down Expand Up @@ -192,6 +210,7 @@ def _nms(
def detect_and_compute(
feats: np.ndarray,
kpts: np.ndarray,
heatmaps: np.ndarray,
resize_rate_w: float,
resize_rate_h: float,
input_size: Tuple[int, int],
Expand Down Expand Up @@ -220,37 +239,12 @@ def detect_and_compute(
kpts_heats = _get_kpts_heatmap(kpts)
mkpts = _nms(kpts_heats, threshold=0.05, kernel_size=5) # int64

# Numpy implementation of normgrid
div_array = np.array([input_size[0] - 1, input_size[1] - 1], dtype=mkpts.dtype)
grid = 2.0 * (mkpts / div_array) - 1.0
grid = np.expand_dims(grid, axis=2)

if grid.size == 0:
if mkpts.size == 0:
return None

# Numpy implementation of F.grid_sample
map_x = grid[..., 0].reshape(-1).astype(np.float32)
map_y = grid[..., 1].reshape(-1).astype(np.float32)
remapped = cv2.remap(
kpts_heats[0, 0],
map_x,
map_y,
interpolation=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0,
)
nearest_result = np.expand_dims(remapped, axis=0)
nearest_result = bilinear(kpts_heats, mkpts, input_size[1], input_size[0])
bilinear_result = bilinear(heatmaps, mkpts, input_size[1], input_size[0])

# Numpy implementation of F.grid_sample
remapped = cv2.remap(
kpts_heats[0, 0],
map_x,
map_y,
interpolation=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0,
)
bilinear_result = np.expand_dims(remapped, axis=0)
scores = (nearest_result * bilinear_result).reshape(1, -1)

scores = scores.astype(np.float32)
Expand All @@ -262,17 +256,12 @@ def detect_and_compute(
mkpts_x = np.take_along_axis(mkpts[..., 0], idxs, axis=-1)[:, :top_k]
mkpts_y = np.take_along_axis(mkpts[..., 1], idxs, axis=-1)[:, :top_k]
mkpts = np.stack([mkpts_x, mkpts_y], axis=-1)
scores = np.take_along_axis(scores, idxs, axis=-1)[:, :top_k]

div_array = np.array([input_size[0] - 1, input_size[1] - 1], dtype=mkpts.dtype)
grid = 2.0 * (mkpts / div_array) - 1.0
grid = np.expand_dims(grid, axis=2)
map_x = grid[..., 0].reshape(-1).astype(np.float32)
map_y = grid[..., 1].reshape(-1).astype(np.float32)
mkpts = mkpts.astype(np.float32)

feats = bilinear_grid_sample(feats, grid, align_corners=False)
feats = feats.transpose(0, 2, 3, 1).squeeze(-2)
scores = np.take_along_axis(scores, idxs, axis=-1)[:, :top_k]

feats = bilinear(feats, mkpts, input_size[1], input_size[0])
feats = feats[0].transpose(2, 1, 0)

norm = np.linalg.norm(feats, axis=-1, keepdims=True)
feats = feats / norm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from .utils.xfeat import detect_and_compute, match


class XFeatParser(dai.node.ThreadedHostNode):
"""Parser class for parsing the output of the XFeat model.
class XFeatMonoParser(dai.node.ThreadedHostNode):
"""Parser class for parsing the output of the XFeat model. It can be used for
parsing the output from one source (e.g. one camera). The reference frame can be set
with trigger method.
Attributes
----------
Expand All @@ -24,6 +26,8 @@ class XFeatParser(dai.node.ThreadedHostNode):
Maximum number of keypoints to keep.
previous_results : np.ndarray
Previous results from the model. Previous results are used to match keypoints between two frames.
trigger : bool
Trigger to set the reference frame.
Output Message/s
----------------
Expand All @@ -48,6 +52,8 @@ def __init__(
@type original_size: Tuple[float, float]
@param input_size: Input image size.
@type input_size: Tuple[float, float]
@param max_keypoints: Maximum number of keypoints to keep.
@type max_keypoints: int
"""
dai.node.ThreadedHostNode.__init__(self)
self.input = self.createInput()
Expand All @@ -56,6 +62,7 @@ def __init__(
self.input_size = input_size
self.max_keypoints = max_keypoints
self.previous_results = None
self.trigger = False

def setOriginalSize(self, original_size):
"""Sets the original image size.
Expand All @@ -81,6 +88,10 @@ def setMaxKeypoints(self, max_keypoints):
"""
self.max_keypoints = max_keypoints

def setTrigger(self) -> None:
    """Sets the trigger to set the reference frame.

    Raises the ``trigger`` flag on this parser. When ``run`` finds the flag set, it
    stores the current result in ``previous_results`` and clears the flag again.
    """
    self.trigger = True

def run(self):
if self.original_size is None:
raise ValueError("Original image size must be specified!")
Expand All @@ -98,17 +109,21 @@ def run(self):
keypoints = output.getTensor("keypoints", dequantize=True).astype(
np.float32
)
heatmaps = output.getTensor("heatmaps", dequantize=True).astype(np.float32)

if len(feats.shape) == 3:
feats = feats.reshape((1,) + feats.shape).transpose(0, 3, 1, 2)
if len(keypoints.shape) == 3:
keypoints = keypoints.reshape((1,) + keypoints.shape).transpose(
0, 3, 1, 2
)
if len(heatmaps.shape) == 3:
heatmaps = heatmaps.reshape((1,) + heatmaps.shape).transpose(0, 3, 1, 2)

result = detect_and_compute(
feats,
keypoints,
heatmaps,
resize_rate_w,
resize_rate_h,
self.input_size,
Expand All @@ -128,6 +143,11 @@ def run(self):
matched_points = create_tracked_features_message(mkpts0, mkpts1)
matched_points.setTimestamp(output.getTimestamp())
self.out.send(matched_points)
else:
matched_points = dai.TrackedFeatures()
matched_points.setTimestamp(output.getTimestamp())
self.out.send(matched_points)

# save the result from first frame
self.previous_results = result
if self.trigger:
self.previous_results = result
self.trigger = False
Loading

0 comments on commit de544a4

Please sign in to comment.