Skip to content

Commit 9968ad2

Browse files
committed
Re-sort bboxes based on layout
1 parent 8cac0dc commit 9968ad2

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

ocr_app.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def layout_detection(img) -> (Image.Image, LayoutResult):
7070
def order_detection(img) -> (Image.Image, OrderResult):
7171
_, layout_pred = layout_detection(img)
7272
bboxes = [l.bbox for l in layout_pred.bboxes]
73-
pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
73+
labels = [l.label for l in layout_pred.bboxes]
74+
pred = batch_ordering([img], [bboxes], order_model, order_processor, labels=[labels])[0]
7475
polys = [l.polygon for l in pred.bboxes]
7576
positions = [str(l.position) for l in pred.bboxes]
7677
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)

reading_order.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,14 @@ def main():
4242
line_predictions = batch_text_detection(images, det_model, det_processor)
4343
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
4444
bboxes = []
45+
labels = []
4546
for layout_pred in layout_predictions:
4647
bbox = [l.bbox for l in layout_pred.bboxes]
48+
label = [l.label for l in layout_pred.bboxes]
4749
bboxes.append(bbox)
50+
labels.append(label)
4851

49-
order_predictions = batch_ordering(images, bboxes, model, processor)
52+
order_predictions = batch_ordering(images, bboxes, model, processor, labels=labels)
5053
result_path = os.path.join(args.results_dir, folder_name)
5154
os.makedirs(result_path, exist_ok=True)
5255

surya/ordering.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from copy import deepcopy
2-
from typing import List
2+
from typing import List, Optional
33
import torch
44
from PIL import Image
55

@@ -30,11 +30,16 @@ def rank_elements(arr):
3030
return rank
3131

3232

33-
def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]:
33+
def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor, labels: Optional[List[List[str]]] = None) -> List[OrderResult]:
3434
assert all([isinstance(image, Image.Image) for image in images])
3535
assert len(images) == len(bboxes)
3636
batch_size = get_batch_size()
3737

38+
if labels is not None:
39+
assert len(labels) == len(images)
40+
for l, b in zip(labels, bboxes):
41+
assert len(l) == len(b)
42+
3843
images = [image.convert("RGB") for image in images]
3944

4045
output_order = []
@@ -64,9 +69,29 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, process
6469
for j in range(logits.shape[0]):
6570
row_logits = logits[j].tolist()
6671
row_bboxes = bboxes[i+j]
72+
assert len(row_logits) == len(row_bboxes), "Mismatch between logits and bboxes."
73+
6774
orig_size = orig_sizes[j]
6875
ranks = rank_elements(row_logits)
6976

77+
if labels is not None:
78+
# Force page headers to the start and page footers to the end of the reading order; everything else keeps its model-predicted relative order
79+
row_label = labels[i+j]
80+
combined = [[i, bbox, label, rank] for i, (bbox, label, rank) in enumerate(zip(row_bboxes, row_label, ranks))]
81+
combined = sorted(combined, key=lambda x: x[3])
82+
83+
sorted_boxes = ([row for row in combined if row[2] == "Page-header"] +
84+
[row for row in combined if row[2] not in ["Page-header", "Page-footer"]] +
85+
[row for row in combined if row[2] == "Page-footer"])
86+
87+
# Re-rank after sorting
88+
for rank, row in enumerate(sorted_boxes):
89+
row[3] = rank
90+
91+
sorted_boxes = sorted(sorted_boxes, key=lambda x: x[0])
92+
row_bboxes = [row[1] for row in sorted_boxes]
93+
ranks = [row[3] for row in sorted_boxes]
94+
7095
order_boxes = []
7196
for row_bbox, rank in zip(row_bboxes, ranks):
7297
order_box = OrderBox(

surya/settings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def TORCH_DEVICE_DETECTION(self) -> str:
6868
RECOGNITION_PAD_VALUE: int = 0 # Should be 0 or 255
6969

7070
# Layout
71-
LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout"
71+
LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout2"
7272
LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench"
7373

7474
# Ordering
75-
ORDER_MODEL_CHECKPOINT: str = "vikp/order2"
75+
ORDER_MODEL_CHECKPOINT: str = "vikp/surya_order"
7676
ORDER_IMAGE_SIZE: Dict = {"height": 1024, "width": 1024}
7777
ORDER_MAX_BOXES: int = 256
7878
ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise

0 commit comments

Comments
 (0)