Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
harrydobbs committed Oct 16, 2024
1 parent cdabdd9 commit 509236f
Show file tree
Hide file tree
Showing 6 changed files with 748 additions and 0 deletions.
120 changes: 120 additions & 0 deletions tests/view_cylinder_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch
import numpy as np
import open3d as o3d
from torch_ransac3d.cylinder import (
cylinder_fit,
estimate_normals,
) # Assuming the previous code is in cylinder_fit.py


def generate_cylinder_points(n_points, center, axis, radius, height, noise_level=0.05):
"""Generate synthetic points on a cylinder surface with noise."""
# Generate random heights along the cylinder axis
h = np.random.uniform(0, height, n_points)

# Generate random angles around the cylinder
theta = np.random.uniform(0, 2 * np.pi, n_points)

# Calculate points on the cylinder surface
x = radius * np.cos(theta)
y = radius * np.sin(theta)

# Rotate the cylinder to align with the specified axis
rotation_axis = np.cross([0, 0, 1], axis)
rotation_angle = np.arccos(np.dot([0, 0, 1], axis))
rotation_matrix = o3d.geometry.get_rotation_matrix_from_axis_angle(
rotation_axis * rotation_angle
)

points = np.dot(rotation_matrix, np.vstack((x, y, h)))
points = points.T + center

# Add noise
noise = np.random.normal(0, noise_level, points.shape)
points += noise

return torch.tensor(points, dtype=torch.float32)


def create_cylinder_mesh(center, axis, radius, height, resolution=50):
"""Create an Open3D cylinder mesh."""
cylinder = o3d.geometry.TriangleMesh.create_cylinder(
radius=radius, height=height, resolution=resolution
)

# Rotate cylinder to align with axis
rotation_matrix = o3d.geometry.get_rotation_matrix_from_axis_angle(
np.cross([0, 0, 1], axis) * np.arccos(np.dot([0, 0, 1], axis))
)
cylinder.rotate(rotation_matrix, center=True)

# Move cylinder to center
cylinder.translate(center)

return cylinder


def visualize_cylinder_fit(points, center, axis, radius, inliers):
# Create Open3D point cloud for all points
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points.numpy())

# Color all points blue
colors = np.zeros_like(points.numpy()) + [0, 0, 1] # Blue

# Color inliers red
colors[inliers.numpy()] = [1, 0, 0] # Red
pcd.colors = o3d.utility.Vector3dVector(colors)

# Create cylinder mesh
height = np.max(points.numpy(), axis=0) - np.min(points.numpy(), axis=0)
height = np.linalg.norm(height)
cylinder_mesh = create_cylinder_mesh(
center.numpy(), axis.numpy(), radius.item(), height
)
cylinder_mesh.paint_uniform_color([0, 1, 0]) # Green

# Create coordinate frame
coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
size=1.0, origin=center.numpy()
)

# Visualize
o3d.visualization.draw_geometries([pcd, cylinder_mesh, coord_frame])


def main():
# Generate synthetic cylinder data
n_points = 1000
true_center = np.array([1.0, 2.0, 3.0]) # Explicitly use floating-point values
true_axis = np.array([1.0, 1.0, 1.0]) # Explicitly use floating-point values
true_axis /= np.linalg.norm(true_axis)
true_radius = 2.0
height = 10.0
noise_level = 0.1

points = generate_cylinder_points(
n_points, true_center, true_axis, true_radius, height, noise_level
)

# Estimate normals
normals = estimate_normals(points)

# Fit cylinder
center, axis, radius, inliers = cylinder_fit(points, normals)

# Print results
print(f"True center: {true_center}")
print(f"Fitted center: {center.numpy()}")
print(f"True axis: {true_axis}")
print(f"Fitted axis: {axis.numpy()}")
print(f"True radius: {true_radius}")
print(f"Fitted radius: {radius.item()}")
print(f"Number of inliers: {len(inliers)}")

# Visualize results
visualize_cylinder_fit(points, center, axis, radius, inliers)


if __name__ == "__main__":
main()
142 changes: 142 additions & 0 deletions torch_ransac3d/circle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import torch
from typing import Tuple
from .wrapper import numpy_to_torch
from .util import rodrigues_rot_torch


@numpy_to_torch
@torch.compile
@torch.no_grad()
def circle_fit(
pts: torch.Tensor,
thresh: float = 0.2,
max_iterations: int = 1000,
iterations_per_batch: int = 1,
epsilon: float = 1e-8,
device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Find the parameters (center, axis, and radius) to define a circle using a batched RANSAC approach.
This function fits a circle to a 3D point cloud using a RANSAC-like method,
processing multiple iterations in parallel for efficiency.
:param pts: 3D point cloud.
:type pts: torch.Tensor
:param thresh: Threshold distance from the circle which is considered inlier.
:type thresh: float
:param max_iterations: Maximum number of iterations for the RANSAC algorithm.
:type max_iterations: int
:param iterations_per_batch: Number of iterations to process in parallel.
:type iterations_per_batch: int
:param epsilon: Small value to avoid division by zero.
:type epsilon: float
:param device: Device to run the computations on.
:type device: torch.device
:return: A tuple containing:
- center (torch.Tensor): Center of the circle (shape: (3,))
- axis (torch.Tensor): Vector describing circle's plane normal (shape: (3,))
- radius (torch.Tensor): Radius of the circle (shape: (1,))
- inliers (torch.Tensor): Indices of points from the dataset considered as inliers
:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Example:
>>> pts = torch.randn(1000, 3)
>>> center, axis, radius, inliers = circle_fit(pts)
>>> print(f"Circle center: {center}")
>>> print(f"Circle axis: {axis}")
>>> print(f"Circle radius: {radius}")
>>> print(f"Number of inliers: {inliers.shape[0]}")
"""
pts = pts.to(device)
num_pts = pts.shape[0]

best_inliers = torch.tensor([], dtype=torch.long, device=device)
best_center = torch.zeros(3, device=device)
best_axis = torch.zeros(3, device=device)
best_radius = torch.tensor(0.0, device=device)

for start_idx in range(0, max_iterations, iterations_per_batch):
end_idx = min(start_idx + iterations_per_batch, max_iterations)
current_batch_size = end_idx - start_idx

# Sample 3 random points for each iteration in the batch
rand_pt_idx = torch.randint(0, num_pts, (current_batch_size, 3), device=device)
pt_samples = pts[rand_pt_idx] # (batch_size, 3, 3)

# Compute vectors for the plane
vec_A = pt_samples[:, 1] - pt_samples[:, 0]
vec_A_norm = vec_A / (torch.norm(vec_A, dim=1, keepdim=True) + epsilon)
vec_B = pt_samples[:, 2] - pt_samples[:, 0]
vec_B_norm = vec_B / (torch.norm(vec_B, dim=1, keepdim=True) + epsilon)

# Compute normal vector to the plane
vec_C = torch.cross(vec_A_norm, vec_B_norm)
vec_C = vec_C / (torch.norm(vec_C, dim=1, keepdim=True) + epsilon)

# Compute plane equation
k = -torch.sum(vec_C * pt_samples[:, 1], dim=1)
plane_eq = torch.cat([vec_C, k.unsqueeze(1)], dim=1) # (batch_size, 4)

# Rotate points to align with z-axis
P_rot = rodrigues_rot_torch(
pt_samples, vec_C, torch.tensor([0, 0, 1], device=device)
)

# Find center from 3 points
ma = (P_rot[:, 1, 1] - P_rot[:, 0, 1]) / (
P_rot[:, 1, 0] - P_rot[:, 0, 0] + epsilon
)
mb = (P_rot[:, 2, 1] - P_rot[:, 1, 1]) / (
P_rot[:, 2, 0] - P_rot[:, 1, 0] + epsilon
)

p_center_x = (
ma * mb * (P_rot[:, 0, 1] - P_rot[:, 2, 1])
+ mb * (P_rot[:, 0, 0] + P_rot[:, 1, 0])
- ma * (P_rot[:, 1, 0] + P_rot[:, 2, 0])
) / (2 * (mb - ma + epsilon))
p_center_y = (
-1 / (ma + epsilon) * (p_center_x - (P_rot[:, 0, 0] + P_rot[:, 1, 0]) / 2)
+ (P_rot[:, 0, 1] + P_rot[:, 1, 1]) / 2
)
p_center = torch.stack(
[p_center_x, p_center_y, torch.zeros_like(p_center_x)], dim=1
)
radius = torch.norm(p_center - P_rot[:, 0, :], dim=1)

# Rotate center back to original orientation
center = rodrigues_rot_torch(
p_center.unsqueeze(1), torch.tensor([0, 0, 1], device=device), vec_C
).squeeze(1)

# Compute distances from points to circle
dist_pt_plane = torch.abs(
torch.sum(pts * plane_eq[:, :3].unsqueeze(1), dim=2)
+ plane_eq[:, 3].unsqueeze(1)
) / (torch.norm(plane_eq[:, :3], dim=1, keepdim=True) + epsilon)
dist_pt_inf_circle = torch.norm(
torch.cross(
vec_C.unsqueeze(1).expand(-1, num_pts, -1),
(center.unsqueeze(1) - pts.unsqueeze(0)),
),
dim=2,
) - radius.unsqueeze(1)
dist_pt = torch.sqrt(dist_pt_inf_circle**2 + dist_pt_plane**2)

# Select inliers
inlier_mask = dist_pt <= thresh
inlier_counts = inlier_mask.sum(dim=1)

# Find the best iteration in this batch
best_in_batch_idx = torch.argmax(inlier_counts)
best_inlier_count_in_batch = inlier_counts[best_in_batch_idx].item()

if best_inlier_count_in_batch > best_inliers.shape[0]:
best_inliers = torch.where(inlier_mask[best_in_batch_idx])[0]
best_center = center[best_in_batch_idx]
best_axis = vec_C[best_in_batch_idx]
best_radius = radius[best_in_batch_idx]

return best_center, best_axis, best_radius, best_inliers
120 changes: 120 additions & 0 deletions torch_ransac3d/cuboid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch
from typing import Tuple
from .wrapper import numpy_to_torch


@numpy_to_torch
@torch.compile
@torch.no_grad()
def cuboid_fit(
pts: torch.Tensor,
thresh: float = 0.05,
max_iterations: int = 5000,
iterations_per_batch: int = 1,
epsilon: float = 1e-8,
device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the best equations for 3 planes which define a complete cuboid using a batched RANSAC approach.
This function fits a cuboid to a 3D point cloud using a RANSAC-like method,
processing multiple iterations in parallel for efficiency.
:param pts: 3D point cloud.
:type pts: torch.Tensor
:param thresh: Threshold distance from the plane which is considered inlier.
:type thresh: float
:param max_iterations: Maximum number of iterations for the RANSAC algorithm.
:type max_iterations: int
:param iterations_per_batch: Number of iterations to process in parallel.
:type iterations_per_batch: int
:param epsilon: Small value to avoid division by zero.
:type epsilon: float
:param device: Device to run the computations on.
:type device: torch.device
:return: A tuple containing:
- best_eq (torch.Tensor): Array of 3 best planes' equations (shape: (3, 4))
- best_inliers (torch.Tensor): Indices of points from the dataset considered as inliers
:rtype: Tuple[torch.Tensor, torch.Tensor]
Example:
>>> pts = torch.randn(1000, 3)
>>> equations, inliers = cuboid_fit(pts)
>>> print(f"Cuboid equations:\n{equations}")
>>> print(f"Number of inliers: {inliers.shape[0]}")
"""
pts = pts.to(device)
num_pts = pts.shape[0]

best_inliers = torch.tensor([], dtype=torch.long, device=device)
best_eq = torch.zeros((3, 4), device=device)

for start_idx in range(0, max_iterations, iterations_per_batch):
end_idx = min(start_idx + iterations_per_batch, max_iterations)
current_batch_size = end_idx - start_idx

# Sample 6 random points for each iteration in the batch
rand_pt_idx = torch.randint(0, num_pts, (current_batch_size, 6), device=device)
pt_samples = pts[rand_pt_idx] # (batch_size, 6, 3)

# Compute vectors for the first plane
vec_A = pt_samples[:, 1] - pt_samples[:, 0]
vec_B = pt_samples[:, 2] - pt_samples[:, 0]
vec_C = torch.cross(vec_A, vec_B)
vec_C = vec_C / (torch.norm(vec_C, dim=1, keepdim=True) + epsilon)

# Compute k for the first plane
k = -torch.sum(vec_C * pt_samples[:, 1], dim=1)
plane_eq = torch.cat([vec_C, k.unsqueeze(1)], dim=1) # (batch_size, 4)

# Compute the second plane
dist_p4_plane = (
torch.sum(plane_eq[:, :3] * pt_samples[:, 3], dim=1) + plane_eq[:, 3]
)
dist_p4_plane = dist_p4_plane / (torch.norm(plane_eq[:, :3], dim=1) + epsilon)
p4_proj_plane = pt_samples[:, 3] - dist_p4_plane.unsqueeze(1) * vec_C

vec_D = p4_proj_plane - pt_samples[:, 3]
vec_E = pt_samples[:, 4] - pt_samples[:, 3]
vec_F = torch.cross(vec_D, vec_E)
vec_F = vec_F / (torch.norm(vec_F, dim=1, keepdim=True) + epsilon)

k = -torch.sum(vec_F * pt_samples[:, 4], dim=1)
plane_eq = torch.cat(
[plane_eq, torch.cat([vec_F, k.unsqueeze(1)], dim=1)], dim=1
) # (batch_size, 8)

# Compute the third plane
vec_G = torch.cross(vec_C, vec_F)
k = -torch.sum(vec_G * pt_samples[:, 5], dim=1)
plane_eq = torch.cat(
[plane_eq, torch.cat([vec_G, k.unsqueeze(1)], dim=1)], dim=1
) # (batch_size, 12)

# Reshape plane_eq to (batch_size, 3, 4)
plane_eq = plane_eq.view(current_batch_size, 3, 4)

# Compute distances of all points to each plane in the batch
dist_pt = torch.abs(
torch.einsum("bij,kj->bik", plane_eq[:, :, :3], pts)
+ plane_eq[:, :, 3].unsqueeze(2)
)
dist_pt = dist_pt / (
torch.norm(plane_eq[:, :, :3], dim=2, keepdim=True) + epsilon
)

# Select inliers
min_dist_pt = torch.min(dist_pt, dim=1)[0]
inlier_mask = min_dist_pt <= thresh
inlier_counts = inlier_mask.sum(dim=1)

# Find the best iteration in this batch
best_in_batch_idx = torch.argmax(inlier_counts)
best_inlier_count_in_batch = inlier_counts[best_in_batch_idx].item()

if best_inlier_count_in_batch > best_inliers.shape[0]:
best_inliers = torch.where(inlier_mask[best_in_batch_idx])[0]
best_eq = plane_eq[best_in_batch_idx]

return best_eq, best_inliers
Loading

0 comments on commit 509236f

Please sign in to comment.