Skip to content

Commit e8c98ac

Browse files
authored
Merge pull request #75 from VikParuchuri/dev
Add reading order model
2 parents 3cdc3b6 + 1abd2f0 commit e8c98ac

34 files changed

+1462
-48
lines changed

.github/workflows/tests.yml

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ jobs:
3636
run: |
3737
poetry run python benchmark/layout.py --max 5
3838
poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
39+
- name: Run ordering benchmark text
40+
run: |
41+
poetry run python benchmark/ordering.py --max 5
42+
poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
3943
4044
4145

README.md

+106-27
Large diffs are not rendered by default.

benchmark/ordering.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import argparse
2+
import collections
3+
import copy
4+
import json
5+
6+
from surya.benchmark.metrics import precision_recall
7+
from surya.model.ordering.model import load_model
8+
from surya.model.ordering.processor import load_processor
9+
from surya.postprocessing.heatmap import draw_bboxes_on_image
10+
from surya.ordering import batch_ordering
11+
from surya.settings import settings
12+
from surya.benchmark.metrics import rank_accuracy
13+
import os
14+
import time
15+
from tabulate import tabulate
16+
import datasets
17+
18+
19+
def main():
20+
parser = argparse.ArgumentParser(description="Benchmark surya reading order model.")
21+
parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
22+
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
23+
args = parser.parse_args()
24+
25+
model = load_model()
26+
processor = load_processor()
27+
28+
pathname = "order_bench"
29+
# These have already been shuffled randomly, so sampling from the start is fine
30+
split = "train"
31+
if args.max is not None:
32+
split = f"train[:{args.max}]"
33+
dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split)
34+
images = list(dataset["image"])
35+
images = [i.convert("RGB") for i in images]
36+
bboxes = list(dataset["bboxes"])
37+
38+
start = time.time()
39+
order_predictions = batch_ordering(images, bboxes, model, processor)
40+
surya_time = time.time() - start
41+
42+
folder_name = os.path.basename(pathname).split(".")[0]
43+
result_path = os.path.join(args.results_dir, folder_name)
44+
os.makedirs(result_path, exist_ok=True)
45+
46+
page_metrics = collections.OrderedDict()
47+
mean_accuracy = 0
48+
for idx, order_pred in enumerate(order_predictions):
49+
row = dataset[idx]
50+
pred_labels = [str(l.position) for l in order_pred.bboxes]
51+
labels = row["labels"]
52+
accuracy = rank_accuracy(pred_labels, labels)
53+
mean_accuracy += accuracy
54+
page_results = {
55+
"accuracy": accuracy,
56+
"box_count": len(labels)
57+
}
58+
59+
page_metrics[idx] = page_results
60+
61+
mean_accuracy /= len(order_predictions)
62+
63+
out_data = {
64+
"time": surya_time,
65+
"mean_accuracy": mean_accuracy,
66+
"page_metrics": page_metrics
67+
}
68+
69+
with open(os.path.join(result_path, "results.json"), "w+") as f:
70+
json.dump(out_data, f, indent=4)
71+
72+
print(f"Mean accuracy is {mean_accuracy:.2f}.")
73+
print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
74+
print("Mean accuracy is the % of correct ranking pairs.")
75+
print(f"Wrote results to {result_path}")
76+
77+
78+
if __name__ == "__main__":
79+
main()

ocr_app.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
from surya.model.detection.segformer import load_model, load_processor
1111
from surya.model.recognition.model import load_model as load_rec_model
1212
from surya.model.recognition.processor import load_processor as load_rec_processor
13+
from surya.model.ordering.processor import load_processor as load_order_processor
14+
from surya.model.ordering.model import load_model as load_order_model
15+
from surya.ordering import batch_ordering
1316
from surya.postprocessing.heatmap import draw_polys_on_image
1417
from surya.ocr import run_ocr
1518
from surya.postprocessing.text import draw_text_on_image
1619
from PIL import Image
1720
from surya.languages import CODE_TO_LANGUAGE
1821
from surya.input.langs import replace_lang_with_code
19-
from surya.schema import OCRResult, TextDetectionResult, LayoutResult
22+
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
2023
from surya.settings import settings
2124

2225
parser = argparse.ArgumentParser(description="Run OCR on an image or PDF.")
@@ -43,15 +46,19 @@ def load_rec_cached():
4346
def load_layout_cached():
4447
return load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT), load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
4548

49+
@st.cache_resource()
50+
def load_order_cached():
51+
return load_order_model(), load_order_processor()
52+
4653

47-
def text_detection(img) -> TextDetectionResult:
54+
def text_detection(img) -> (Image.Image, TextDetectionResult):
4855
pred = batch_text_detection([img], det_model, det_processor)[0]
4956
polygons = [p.polygon for p in pred.bboxes]
5057
det_img = draw_polys_on_image(polygons, img.copy())
5158
return det_img, pred
5259

5360

54-
def layout_detection(img) -> LayoutResult:
61+
def layout_detection(img) -> (Image.Image, LayoutResult):
5562
_, det_pred = text_detection(img)
5663
pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
5764
polygons = [p.polygon for p in pred.bboxes]
@@ -60,8 +67,18 @@ def layout_detection(img) -> LayoutResult:
6067
return layout_img, pred
6168

6269

70+
def order_detection(img) -> (Image.Image, OrderResult):
71+
_, layout_pred = layout_detection(img)
72+
bboxes = [l.bbox for l in layout_pred.bboxes]
73+
pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
74+
polys = [l.polygon for l in pred.bboxes]
75+
positions = [str(l.position) for l in pred.bboxes]
76+
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
77+
return order_img, pred
78+
79+
6380
# Function for OCR
64-
def ocr(img, langs: List[str]) -> OCRResult:
81+
def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
6582
replace_lang_with_code(langs)
6683
img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
6784

@@ -101,6 +118,7 @@ def page_count(pdf_file):
101118
det_model, det_processor = load_det_cached()
102119
rec_model, rec_processor = load_rec_cached()
103120
layout_model, layout_processor = load_layout_cached()
121+
order_model, order_processor = load_order_cached()
104122

105123

106124
st.markdown("""
@@ -136,24 +154,28 @@ def page_count(pdf_file):
136154
text_det = st.sidebar.button("Run Text Detection")
137155
text_rec = st.sidebar.button("Run OCR")
138156
layout_det = st.sidebar.button("Run Layout Analysis")
157+
order_det = st.sidebar.button("Run Reading Order")
158+
159+
if pil_image is None:
160+
st.stop()
139161

140162
# Run Text Detection
141-
if text_det and pil_image is not None:
163+
if text_det:
142164
det_img, pred = text_detection(pil_image)
143165
with col1:
144166
st.image(det_img, caption="Detected Text", use_column_width=True)
145167
st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True)
146168

147169

148170
# Run layout
149-
if layout_det and pil_image is not None:
171+
if layout_det:
150172
layout_img, pred = layout_detection(pil_image)
151173
with col1:
152174
st.image(layout_img, caption="Detected Layout", use_column_width=True)
153175
st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True)
154176

155177
# Run OCR
156-
if text_rec and pil_image is not None:
178+
if text_rec:
157179
rec_img, pred = ocr(pil_image, languages)
158180
with col1:
159181
st.image(rec_img, caption="OCR Result", use_column_width=True)
@@ -163,5 +185,11 @@ def page_count(pdf_file):
163185
with text_tab:
164186
st.text("\n".join([p.text for p in pred.text_lines]))
165187

188+
if order_det:
189+
order_img, pred = order_detection(pil_image)
190+
with col1:
191+
st.image(order_img, caption="Reading Order", use_column_width=True)
192+
st.json(pred.model_dump(), expanded=True)
193+
166194
with col2:
167195
st.image(pil_image, caption="Uploaded Image", use_column_width=True)

pyproject.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.3.0"
4-
description = "OCR, layout analysis, and line detection in 90+ languages"
3+
version = "0.4.0"
4+
description = "OCR, layout, reading order, and line detection in 90+ languages"
55
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
66
readme = "README.md"
77
license = "GPL-3.0-or-later"
@@ -15,7 +15,8 @@ include = [
1515
"ocr_text.py",
1616
"ocr_app.py",
1717
"run_ocr_app.py",
18-
"detect_layout.py"
18+
"detect_layout.py",
19+
"reading_order.py",
1920
]
2021

2122
[tool.poetry.dependencies]
@@ -48,6 +49,7 @@ surya_detect = "detect_text:main"
4849
surya_ocr = "ocr_text:main"
4950
surya_layout = "detect_layout:main"
5051
surya_gui = "run_ocr_app:run_app"
52+
surya_order = "reading_order:main"
5153

5254
[build-system]
5355
requires = ["poetry-core"]

reading_order.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import argparse
2+
import copy
3+
import json
4+
from collections import defaultdict
5+
6+
from surya.detection import batch_text_detection
7+
from surya.input.load import load_from_folder, load_from_file
8+
from surya.layout import batch_layout_detection
9+
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
10+
from surya.model.ordering.model import load_model
11+
from surya.model.ordering.processor import load_processor
12+
from surya.ordering import batch_ordering
13+
from surya.postprocessing.heatmap import draw_polys_on_image
14+
from surya.settings import settings
15+
import os
16+
17+
18+
def main():
19+
parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).")
20+
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.")
21+
parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
22+
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
23+
parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
24+
args = parser.parse_args()
25+
26+
model = load_model()
27+
processor = load_processor()
28+
29+
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
30+
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
31+
32+
det_model = load_det_model()
33+
det_processor = load_det_processor()
34+
35+
if os.path.isdir(args.input_path):
36+
images, names = load_from_folder(args.input_path, args.max)
37+
folder_name = os.path.basename(args.input_path)
38+
else:
39+
images, names = load_from_file(args.input_path, args.max)
40+
folder_name = os.path.basename(args.input_path).split(".")[0]
41+
42+
line_predictions = batch_text_detection(images, det_model, det_processor)
43+
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
44+
bboxes = []
45+
for layout_pred in layout_predictions:
46+
bbox = [l.bbox for l in layout_pred.bboxes]
47+
bboxes.append(bbox)
48+
49+
order_predictions = batch_ordering(images, bboxes, model, processor)
50+
result_path = os.path.join(args.results_dir, folder_name)
51+
os.makedirs(result_path, exist_ok=True)
52+
53+
if args.images:
54+
for idx, (image, layout_pred, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)):
55+
polys = [l.polygon for l in order_pred.bboxes]
56+
labels = [str(l.position) for l in order_pred.bboxes]
57+
bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20)
58+
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png"))
59+
60+
predictions_by_page = defaultdict(list)
61+
for idx, (layout_pred, pred, name, image) in enumerate(zip(layout_predictions, order_predictions, names, images)):
62+
out_pred = pred.model_dump()
63+
for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes):
64+
bbox["label"] = layout_bbox.label
65+
66+
out_pred["page"] = len(predictions_by_page[name]) + 1
67+
predictions_by_page[name].append(out_pred)
68+
69+
# Sort in reading order
70+
for name in predictions_by_page:
71+
for page_preds in predictions_by_page[name]:
72+
page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"])
73+
74+
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
75+
json.dump(predictions_by_page, f, ensure_ascii=False)
76+
77+
print(f"Wrote results to {result_path}")
78+
79+
80+
if __name__ == "__main__":
81+
main()

scripts/verify_benchmark_scores.py

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ def verify_rec(data):
2121
raise ValueError("Scores do not meet the required threshold")
2222

2323

24+
def verify_order(data):
25+
score = data["mean_accuracy"]
26+
if score < 0.75:
27+
raise ValueError("Scores do not meet the required threshold")
28+
29+
2430
def verify_scores(file_path, bench_type):
2531
with open(file_path, 'r') as file:
2632
data = json.load(file)
@@ -31,6 +37,8 @@ def verify_scores(file_path, bench_type):
3137
verify_rec(data)
3238
elif bench_type == "layout":
3339
verify_layout(data)
40+
elif bench_type == "ordering":
41+
verify_order(data)
3442
else:
3543
raise ValueError("Invalid benchmark type")
3644

static/images/arabic_reading.jpg

304 KB
Loading

static/images/chi_hind_reading.jpg

482 KB
Loading

static/images/chinese_reading.jpg

328 KB
Loading

static/images/excerpt_reading.jpg

351 KB
Loading

static/images/funsd_layout.jpg

195 KB
Loading

static/images/funsd_reading.jpg

195 KB
Loading

static/images/gcloud_full_langs.png

138 KB
Loading

static/images/gcloud_rec_bench.png

34 KB
Loading

static/images/hindi_reading.jpg

306 KB
Loading

static/images/japanese_reading.jpg

419 KB
Loading

static/images/nyt_order.jpg

2.08 MB
Loading

static/images/paper_reading.jpg

604 KB
Loading

static/images/pres_reading.jpg

524 KB
Loading

static/images/scanned_reading.jpg

865 KB
Loading

static/images/textbook_order.jpg

384 KB
Loading

surya/benchmark/metrics.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,23 @@ def mean_coverage(preds, references):
117117
if len(coverages) == 0:
118118
return 0
119119
coverage = sum(coverages) / len(coverages)
120-
return {"coverage": coverage}
120+
return {"coverage": coverage}
121+
122+
123+
def rank_accuracy(preds, references):
124+
# Preds and references need to be aligned so each position refers to the same bbox
125+
pairs = []
126+
for i, pred in enumerate(preds):
127+
for j, pred2 in enumerate(preds):
128+
if i == j:
129+
continue
130+
pairs.append((i, j, pred > pred2))
131+
132+
# Find how many of the prediction rankings are correct
133+
correct = 0
134+
for i, ref in enumerate(references):
135+
for j, ref2 in enumerate(references):
136+
if (i, j, ref > ref2) in pairs:
137+
correct += 1
138+
139+
return correct / len(pairs)

0 commit comments

Comments
 (0)