Commit 51c4c5f

Add slicing logic
1 parent e790f50 commit 51c4c5f

4 files changed, +180 -8 lines


surya/input/slicing.py

+129
@@ -0,0 +1,129 @@
+import math
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from surya.schema import LayoutResult
+
+SLICES_TYPE = Tuple[List[Image.Image], List[Tuple[int, int, int]]]
+
+
+class ImageSlicer:
+    merge_tolerance = .05
+
+    def __init__(self, slice_min_dims, max_slices=4):
+        self.slice_min_dims = slice_min_dims
+        self.max_slices = max_slices
+
+    def slice(self, images: List[Image.Image]) -> SLICES_TYPE:
+        all_slices = []
+        all_positions = []
+
+        for idx, image in enumerate(images):
+            if (image.size[0] > self.slice_min_dims["width"] or
+                    image.size[1] > self.slice_min_dims["height"]):
+                img_slices, positions = self._slice_image(image, idx)
+                all_slices.extend(img_slices)
+                all_positions.extend(positions)
+            else:
+                all_slices.append(image)
+                all_positions.append((idx, 0, 0))
+
+        return all_slices, all_positions
+
+    def slice_count(self, image: Image.Image) -> int:
+        width, height = image.size
+        if width > height:
+            slice_size = self._calculate_slice_size(width, "width")
+            return math.ceil(width / slice_size)
+        else:
+            slice_size = self._calculate_slice_size(height, "height")
+            return math.ceil(height / slice_size)
+
+    def _calculate_slice_size(self, dimension: int, dim_type: str) -> int:
+        min_size = self.slice_min_dims[dim_type]
+        return max(min_size, (dimension // self.max_slices + 1))
+
+    def _slice_image(self, image: Image.Image, idx: int) -> SLICES_TYPE:
+        width, height = image.size
+        slices = []
+        positions = []
+
+        if width > height:
+            slice_size = self._calculate_slice_size(width, "width")
+            for i, x in enumerate(range(0, width, slice_size)):
+                slice_end = min(x + slice_size, width)
+                slices.append(image.crop((x, 0, slice_end, height)))
+                positions.append((idx, i, 0))
+        else:
+            slice_size = self._calculate_slice_size(height, "height")
+            for i, y in enumerate(range(0, height, slice_size)):
+                slice_end = min(y + slice_size, height)
+                slices.append(image.crop((0, y, width, slice_end)))
+                positions.append((idx, 0, i))
+
+        return slices, positions
+
+    def join(self, results: List[LayoutResult], tile_positions: List[Tuple[int, int, int]]) -> List[LayoutResult]:
+        new_results = []
+        current_result = None
+        for idx, (result, tile_position) in enumerate(zip(results, tile_positions)):
+            image_idx, tile_x, tile_y = tile_position
+            if idx == 0 or image_idx != tile_positions[idx - 1][0]:
+                if current_result is not None:
+                    new_results.append(current_result)
+                current_result = result
+            else:
+                merge_dir = "width" if tile_x > 0 else "height"
+                current_result = self.merge_results(current_result, result, merge_dir=merge_dir)
+        if current_result is not None:
+            new_results.append(current_result)
+        return new_results
+
+
+    def merge_results(self, res1: LayoutResult, res2: LayoutResult, merge_dir="width") -> LayoutResult:
+        new_image_bbox = res1.image_bbox.copy()
+        to_remove_idxs = set()
+        if merge_dir == "width":
+            new_image_bbox[2] += res2.image_bbox[2]
+            max_position = max([box.position for box in res1.bboxes])
+            for i, box2 in enumerate(res2.bboxes):
+                box2.shift(x_shift=res1.image_bbox[2])
+                box2.position += max_position
+                for j, box1 in enumerate(res1.bboxes):
+                    if all([
+                        box1.intersection_area(box2, x_margin=.1) > self.merge_tolerance,
+                        (
+                            box1.y_overlap(box2, y_margin=.1) > box1.height // 2 or
+                            box2.y_overlap(box1, y_margin=.1) > box2.height // 2
+                        ),
+                        box1.label == box2.label
+                    ]):
+                        box1.merge(box2)
+                        to_remove_idxs.add(i)
+
+        elif merge_dir == "height":
+            new_image_bbox[3] += res2.image_bbox[3]
+            max_position = max([box.position for box in res1.bboxes])
+            for i, box2 in enumerate(res2.bboxes):
+                box2.shift(y_shift=res1.image_bbox[3])
+                box2.position += max_position
+                for j, box1 in enumerate(res1.bboxes):
+                    if all([
+                        box1.intersection_area(box2, y_margin=.1) > self.merge_tolerance,
+                        (
+                            box1.x_overlap(box2, x_margin=.1) > box1.width // 2 or
+                            box2.x_overlap(box1, x_margin=.1) > box2.width // 2
+                        ),
+                        box1.label == box2.label
+                    ]):
+                        box1.merge(box2)
+                        to_remove_idxs.add(i)
+
+        new_result = LayoutResult(
+            image_bbox=new_image_bbox,
+            bboxes=res1.bboxes + [b for i, b in enumerate(res2.bboxes) if i not in to_remove_idxs]
+        )
+        return new_result
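
As a reading aid, here is a minimal sketch of how ImageSlicer behaves on its own. The 1200px thresholds mirror the new LAYOUT_SLICE_SIZE default below; the 2500x800 page is an invented example, not part of this commit:

from PIL import Image
from surya.input.slicing import ImageSlicer

# Slice anything wider or taller than 1200px, into at most 4 slices.
slicer = ImageSlicer({"width": 1200, "height": 1200}, max_slices=4)

# A hypothetical wide page: 2500px exceeds the width threshold.
page = Image.new("RGB", (2500, 800))

# slice_size = max(1200, 2500 // 4 + 1) = 1200, so ceil(2500 / 1200) = 3 slices.
print(slicer.slice_count(page))  # 3

slices, positions = slicer.slice([page])
print([s.size for s in slices])  # [(1200, 800), (1200, 800), (100, 800)]
print(positions)                 # [(0, 0, 0), (0, 1, 0), (0, 2, 0)] as (image_idx, tile_x, tile_y)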

surya/layout.py

+33-5
@@ -6,6 +6,7 @@
 
 from tqdm import tqdm
 
+from surya.input.slicing import ImageSlicer
 from surya.model.layout.config import ID_TO_LABEL
 from surya.postprocessing.heatmap import clean_boxes, intersects_other_boxes
 from surya.schema import LayoutResult, LayoutBox
@@ -68,10 +69,31 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
     if batch_size is None:
         batch_size = get_batch_size()
 
+    slicer = ImageSlicer(settings.LAYOUT_SLICE_SIZE)
+
+    batches = []
+    img_counts = [slicer.slice_count(image) for image in images]
+
+    start_idx = 0
+    end_idx = 1
+    while end_idx < len(img_counts):
+        if any([
+            sum(img_counts[start_idx:end_idx]) >= batch_size,
+            sum(img_counts[start_idx:end_idx + 1]) > batch_size,
+        ]):
+            batches.append((start_idx, end_idx))
+            start_idx = end_idx
+        end_idx += 1
+
+    if start_idx < len(img_counts):
+        batches.append((start_idx, len(img_counts)))
+
     results = []
-    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing layout"):
-        batch_images = images[i:i+batch_size]
+    for (start_idx, end_idx) in tqdm(batches, desc="Recognizing layout"):
+        batch_results = []
+        batch_images = images[start_idx:end_idx]
         batch_images = [image.convert("RGB") for image in batch_images]  # also copies the image
+        batch_images, tile_positions = slicer.slice(batch_images)
         current_batch_size = len(batch_images)
 
         orig_sizes = [image.size for image in batch_images]
@@ -84,15 +106,15 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
         start_token = [model.config.decoder.bos_token_id] * 7
         batch_decoder_input = [
            [start_token] + [pause_token] * model.config.decoder.pause_token_count
-            for j in range(current_batch_size)
+           for _ in range(current_batch_size)
        ]
         batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
         inference_token_count = batch_decoder_input.shape[1]
 
         decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
         model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
 
-        batch_predictions = [[] for _ in range(len(images))]
+       batch_predictions = [[] for _ in range(current_batch_size)]
 
         with torch.inference_mode():
             encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values)[0]
@@ -188,5 +210,11 @@ def batch_layout_detection(images: List, model, processor, batch_size=None) -> L
                 bboxes=boxes,
                 image_bbox=[0, 0, orig_size[0], orig_size[1]]
             )
-            results.append(result)
+           batch_results.append(result)
+
+       assert len(batch_results) == len(tile_positions)
+       batch_results = slicer.join(batch_results, tile_positions)
+       results.extend(batch_results)
+
+   assert len(results) == len(images)
     return results
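
The while loop above packs whole images into consecutive batches by their slice counts, so all tiles of one image land in the same batch (which slicer.join() relies on). A standalone sketch of that packing rule, with invented slice counts and a hypothetical batch_size of 4:

def pack_batches(img_counts, batch_size):
    # Greedily group consecutive images; close a group once it has reached
    # batch_size, or when adding the next image would push it over.
    # A single image's tiles are never split across groups.
    batches = []
    start_idx, end_idx = 0, 1
    while end_idx < len(img_counts):
        if any([
            sum(img_counts[start_idx:end_idx]) >= batch_size,
            sum(img_counts[start_idx:end_idx + 1]) > batch_size,
        ]):
            batches.append((start_idx, end_idx))
            start_idx = end_idx
        end_idx += 1
    if start_idx < len(img_counts):
        batches.append((start_idx, len(img_counts)))
    return batches

# Images yielding 1, 3, 1, and 2 slices: the first group closes at 4 slices.
print(pack_batches([1, 3, 1, 2], 4))  # [(0, 2), (2, 4)]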

surya/schema.py

+16-2
@@ -72,10 +72,16 @@ def merge(self, other):
         self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
 
     def intersection_area(self, other, x_margin=0, y_margin=0):
-        x_overlap = max(0, min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin))
-        y_overlap = max(0, min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin))
+        x_overlap = self.x_overlap(other, x_margin)
+        y_overlap = self.y_overlap(other, y_margin)
         return x_overlap * y_overlap
 
+    def x_overlap(self, other, x_margin=0):
+        return max(0, min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin))
+
+    def y_overlap(self, other, y_margin=0):
+        return max(0, min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin))
+
     def intersection_pct(self, other, x_margin=0, y_margin=0):
         assert 0 <= x_margin <= 1
         assert 0 <= y_margin <= 1
@@ -90,6 +96,14 @@ def intersection_pct(self, other, x_margin=0, y_margin=0):
         intersection = self.intersection_area(other, x_margin, y_margin)
         return intersection / self.area
 
+    def shift(self, x_shift: float | None = None, y_shift: float | None = None):
+        if x_shift is not None:
+            for corner in self.polygon:
+                corner[0] += x_shift
+        if y_shift is not None:
+            for corner in self.polygon:
+                corner[1] += y_shift
+
 
 class Bbox(BaseModel):
     bbox: List[float]
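
To make the new overlap helpers concrete, here is the same formula applied to raw bbox lists (a stand-in for illustration, not the pydantic class itself):

def x_overlap(a, b, x_margin=0):
    # 1-D overlap of the x-intervals [a[0], a[2]] and [b[0], b[2]],
    # with each interval widened by x_margin on both sides.
    return max(0, min(a[2] + x_margin, b[2] + x_margin) - max(a[0] - x_margin, b[0] - x_margin))

box_a = [0, 0, 10, 10]
box_b = [8, 0, 20, 10]
print(x_overlap(box_a, box_b))     # 2
print(x_overlap(box_a, box_b, 1))  # 4: both intervals widened by 1

shift() plays the complementary role when joining tiles: merge_results() calls it with the first tile's extent so a second tile's polygon corners are translated back into full-image coordinates before the overlap tests run.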

surya/settings.py

+2-1
@@ -65,8 +65,9 @@ def TORCH_DEVICE_MODEL(self) -> str:
     RECOGNITION_ENCODER_BATCH_DIVISOR: int = 1  # Divisor for batch size in decoder
 
     # Layout
-    LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/layout_order_hr4"
+    LAYOUT_MODEL_CHECKPOINT: str = "datalab-to/layout_order_hr3"
     LAYOUT_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
+    LAYOUT_SLICE_SIZE: Dict = {"height": 1200, "width": 1200}  # When to start slicing images
     LAYOUT_BATCH_SIZE: Optional[int] = None
     LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"
     LAYOUT_MAX_BOXES: int = 100
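
Tying the setting to the slicer, a small sketch (assuming the module-level settings instance that batch_layout_detection already reads LAYOUT_SLICE_SIZE from):

from PIL import Image
from surya.input.slicing import ImageSlicer
from surya.settings import settings

# Mirrors the construction in batch_layout_detection: images larger than
# 1200px in either dimension get sliced before layout detection.
slicer = ImageSlicer(settings.LAYOUT_SLICE_SIZE)

small = Image.new("RGB", (800, 1000))  # under both 1200px thresholds
tall = Image.new("RGB", (900, 3000))   # tall page, sliced along height
print(slicer.slice_count(small))  # 1
print(slicer.slice_count(tall))   # 3: slice_size = max(1200, 3000 // 4 + 1) = 1200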
