Skip to content

Commit b01e129

Browse files
committed
Cleanup layout postprocessing
1 parent 6f8f17f commit b01e129

File tree

3 files changed

+46
-83
lines changed

3 files changed

+46
-83
lines changed

detect_layout.py

-23
Original file line numberDiff line numberDiff line change
@@ -41,29 +41,6 @@ def main():
4141
result_path = os.path.join(args.results_dir, folder_name)
4242
os.makedirs(result_path, exist_ok=True)
4343

44-
for idx, (layout_pred, line_pred, name) in enumerate(zip(layout_predictions, line_predictions, names)):
45-
blocks = layout_pred.bboxes
46-
for line in line_pred.vertical_lines:
47-
new_blocks = []
48-
for block in blocks:
49-
block_modified = False
50-
51-
if line.bbox[0] > block.bbox[0] and line.bbox[2] < block.bbox[2]:
52-
overlap_pct = (min(line.bbox[3], block.bbox[3]) - max(line.bbox[1], block.bbox[1])) / (
53-
block.bbox[3] - block.bbox[1])
54-
if overlap_pct > 0.5:
55-
block1 = copy.deepcopy(block)
56-
block2 = copy.deepcopy(block)
57-
block1.bbox[2] = line.bbox[0]
58-
block2.bbox[0] = line.bbox[2]
59-
new_blocks.append(block1)
60-
new_blocks.append(block2)
61-
block_modified = True
62-
if not block_modified:
63-
new_blocks.append(block)
64-
blocks = new_blocks
65-
layout_pred.bboxes = blocks
66-
6744
if args.images:
6845
for idx, (image, layout_pred, line_pred, name) in enumerate(zip(images, layout_predictions, line_predictions, names)):
6946
polygons = [p.polygon for p in layout_pred.bboxes]

surya/layout.py

+45-58
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from collections import defaultdict
23
from typing import List, Optional
34
from PIL import Image
45
import numpy as np
@@ -28,7 +29,7 @@ def bbox_avg(integral_image, x1, y1, x2, y2):
2829
def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[Image.Image], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
2930
logits = np.stack(heatmaps, axis=0)
3031
vertical_line_bboxes = [line for line in detection_result.vertical_lines]
31-
line_bboxes = [line for line in detection_result.bboxes]
32+
line_bboxes = detection_result.bboxes
3233

3334
# Scale back to processor size
3435
for line in vertical_line_bboxes:
@@ -51,66 +52,57 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
5152
logits[i, segment_assignment != i] = 0
5253

5354
detected_boxes = []
54-
done_maps = set()
55-
for iteration in range(100): # detect up to 100 boxes
56-
bbox = None
57-
confidence = None
58-
for heatmap_idx in range(1, len(id2label)): # Skip the blank class
59-
if heatmap_idx in done_maps:
55+
for heatmap_idx in range(1, len(id2label)): # Skip the blank class
56+
heatmap = logits[heatmap_idx]
57+
bboxes = get_detected_boxes(heatmap, text_threshold=.9, low_text=.8)
58+
bboxes = [bbox for bbox in bboxes if bbox.area > 25]
59+
for bb in bboxes:
60+
bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])
61+
62+
integral_image = compute_integral_image(heatmap)
63+
bbox_confidences = [bbox_avg(integral_image, *[int(b) for b in bbox.bbox]) for bbox in bboxes]
64+
for confidence, bbox in zip(bbox_confidences, bboxes):
65+
if confidence <= .3:
6066
continue
61-
heatmap = logits[heatmap_idx]
62-
bboxes = get_detected_boxes(heatmap, text_threshold=.9)
63-
bboxes = [bbox for bbox in bboxes if bbox.area > 25]
64-
for bb in bboxes:
65-
bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])
66-
67-
if len(bboxes) == 0:
68-
done_maps.add(heatmap_idx)
69-
continue
70-
71-
integral_image = compute_integral_image(heatmap)
72-
bbox_confidences = [bbox_avg(integral_image, *[int(b) for b in bbox.bbox]) for bbox in bboxes]
73-
74-
max_confidence = max(bbox_confidences)
75-
max_confidence_idx = bbox_confidences.index(max_confidence)
76-
if max_confidence >= .15 and (confidence is None or max_confidence > confidence):
77-
bbox = LayoutBox(polygon=bboxes[max_confidence_idx].polygon, label=id2label[heatmap_idx])
78-
elif max_confidence < .15:
79-
done_maps.add(heatmap_idx)
80-
81-
if bbox is None:
82-
break
83-
84-
# Expand bbox to cover intersecting lines
85-
remove_indices = []
86-
covered_lines = []
67+
bbox = LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=confidence)
68+
detected_boxes.append(bbox)
69+
70+
detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True)
71+
# Expand bbox to cover intersecting lines
72+
box_lines = defaultdict(list)
73+
used_lines = set()
74+
for bbox_idx, bbox in enumerate(detected_boxes):
8775
for line_idx, line_bbox in enumerate(line_bboxes):
88-
if line_bbox.intersection_pct(bbox) >= .5:
89-
remove_indices.append(line_idx)
90-
covered_lines.append(line_bbox.bbox)
76+
if line_bbox.intersection_pct(bbox) >= .5 and line_idx not in used_lines:
77+
box_lines[bbox_idx].append(line_bbox.bbox)
78+
used_lines.add(line_idx)
9179

92-
logits[:, int(bbox.bbox[1]):int(bbox.bbox[3]), int(bbox.bbox[0]):int(bbox.bbox[2])] = 0 # zero out where the detected bbox is
93-
if len(covered_lines) == 0 and bbox.label not in ["Picture", "Formula"]:
80+
new_boxes = []
81+
for bbox_idx, bbox in enumerate(detected_boxes):
82+
if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]:
9483
continue
9584

96-
if len(covered_lines) > 0 and bbox.label == "Picture":
85+
if bbox_idx in box_lines and bbox.label in ["Picture"]:
9786
bbox.label = "Figure"
9887

88+
covered_lines = box_lines[bbox_idx]
9989
if len(covered_lines) > 0 and bbox.label not in ["Picture"]:
10090
min_x = min([line[0] for line in covered_lines])
10191
min_y = min([line[1] for line in covered_lines])
10292
max_x = max([line[2] for line in covered_lines])
10393
max_y = max([line[3] for line in covered_lines])
10494

105-
min_x_box = min([b[0] for b in bbox.polygon])
106-
min_y_box = min([b[1] for b in bbox.polygon])
107-
max_x_box = max([b[0] for b in bbox.polygon])
108-
max_y_box = max([b[1] for b in bbox.polygon])
95+
if bbox.label in ["Figure", "Table", "Formula"]:
96+
# Figures can tables can contain text, but text isn't the whole area
97+
min_x_box = min([b[0] for b in bbox.polygon])
98+
min_y_box = min([b[1] for b in bbox.polygon])
99+
max_x_box = max([b[0] for b in bbox.polygon])
100+
max_y_box = max([b[1] for b in bbox.polygon])
109101

110-
min_x = min(min_x, min_x_box)
111-
min_y = min(min_y, min_y_box)
112-
max_x = max(max_x, max_x_box)
113-
max_y = max(max_y, max_y_box)
102+
min_x = min(min_x, min_x_box)
103+
min_y = min(min_y, min_y_box)
104+
max_x = max(max_x, max_x_box)
105+
max_y = max(max_y, max_y_box)
114106

115107
bbox.polygon[0][0] = min_x
116108
bbox.polygon[0][1] = min_y
@@ -121,21 +113,16 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
121113
bbox.polygon[3][0] = min_x
122114
bbox.polygon[3][1] = max_y
123115

124-
# Remove "used" overlap lines
125-
line_bboxes = [line_bboxes[i] for i in range(len(line_bboxes)) if i not in remove_indices]
126-
detected_boxes.append(bbox)
127-
128-
logits[:, int(bbox.bbox[1]):int(bbox.bbox[3]), int(bbox.bbox[0]):int(bbox.bbox[2])] = 0 # zero out where the new box is
116+
new_boxes.append(bbox)
129117

130-
if len(line_bboxes) > 0:
131-
for bbox in line_bboxes:
132-
detected_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text"))
118+
unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines]
119+
for bbox in unused_lines:
120+
new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))
133121

134-
for bbox in detected_boxes:
122+
for bbox in new_boxes:
135123
bbox.rescale(list(reversed(heatmap.shape)), orig_size)
136124

137-
detected_boxes = [bbox for bbox in detected_boxes if bbox.area > 16]
138-
detected_boxes = clean_contained_boxes(detected_boxes)
125+
detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]
139126
return detected_boxes
140127

141128

surya/postprocessing/heatmap.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ def keep_largest_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
2626
if box == other_box:
2727
continue
2828
# find overlap percentage
29-
overlap = max(0, min(box[2], other_box[2]) - max(box[0], other_box[0])) * max(0, min(box[3], other_box[3]) - max(box[1], other_box[1]))
30-
overlap = overlap / box_area
29+
overlap = box_obj.intersection_pct(other_box_obj)
3130
if overlap > .9 and box_area < other_box_area:
3231
contained = True
3332
break

0 commit comments

Comments
 (0)