Skip to content

Commit 97f1432

Browse files
authored
Merge pull request #9 from VikParuchuri/dev
Add text recognition (OCR)
2 parents c2067a2 + e9c26ad commit 97f1432

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+3913
-1026
lines changed

.github/workflows/tests.yml

+6-2
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,14 @@ jobs:
2424
poetry install
2525
poetry remove torch
2626
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
27-
- name: Run benchmark test
27+
- name: Run detection benchmark test
2828
run: |
2929
poetry run python benchmark/detection.py --max 2
30-
poetry run python scripts/verify_benchmark_scores.py results/benchmark/doclaynet_bench/results.json
30+
poetry run python scripts/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
31+
- name: Run recognition benchmark test
32+
run: |
33+
poetry run python benchmark/recognition.py --max 2
34+
poetry run python scripts/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
3135
3236
3337

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ wandb
88
notebooks
99
results
1010
data
11+
slices
1112

1213
# Byte-compiled / optimized / DLL files
1314
__pycache__/

README.md

+138-47
Large diffs are not rendered by default.

benchmark/detection.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
from surya.benchmark.bbox import get_pdf_lines
77
from surya.benchmark.metrics import precision_recall
8-
from surya.benchmark.tesseract import tesseract_bboxes, tesseract_parallel
9-
from surya.model.segformer import load_model, load_processor
10-
from surya.model.processing import open_pdf, get_page_images
11-
from surya.detection import batch_inference
12-
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
8+
from surya.benchmark.tesseract import tesseract_parallel
9+
from surya.model.detection.segformer import load_model, load_processor
10+
from surya.input.processing import open_pdf, get_page_images
11+
from surya.detection import batch_detection
12+
from surya.postprocessing.heatmap import draw_polys_on_image
1313
from surya.postprocessing.util import rescale_bbox
1414
from surya.settings import settings
1515
import os
@@ -42,9 +42,9 @@ def main():
4242
image_sizes = [img.size for img in images]
4343
correct_boxes = get_pdf_lines(args.pdf_path, image_sizes)
4444
else:
45-
pathname = "doclaynet_bench"
45+
pathname = "det_bench"
4646
# These have already been shuffled randomly, so sampling from the start is fine
47-
dataset = datasets.load_dataset(settings.BENCH_DATASET_NAME, split=f"train[:{args.max}]")
47+
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
4848
images = list(dataset["image"])
4949
images = [i.convert("RGB") for i in images]
5050
correct_boxes = []
@@ -54,7 +54,7 @@ def main():
5454
correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])
5555

5656
start = time.time()
57-
predictions = batch_inference(images, model, processor)
57+
predictions = batch_detection(images, model, processor)
5858
surya_time = time.time() - start
5959

6060
start = time.time()

benchmark/pymupdf_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from surya.benchmark.bbox import get_pdf_lines
55
from surya.postprocessing.heatmap import draw_bboxes_on_image
66

7-
from surya.model.processing import open_pdf, get_page_images
7+
from surya.input.processing import open_pdf, get_page_images
88
from surya.settings import settings
99

1010

benchmark/recognition.py

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import argparse
2+
from collections import defaultdict
3+
4+
from benchmark.scoring import overlap_score
5+
from surya.model.recognition.model import load_model as load_recognition_model
6+
from surya.model.recognition.processor import load_processor as load_recognition_processor
7+
from surya.ocr import run_recognition
8+
from surya.postprocessing.text import draw_text_on_image
9+
from surya.settings import settings
10+
from surya.languages import CODE_TO_LANGUAGE
11+
from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
12+
import os
13+
import datasets
14+
import json
15+
import time
16+
from tabulate import tabulate
17+
18+
KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]
19+
20+
21+
def main():
22+
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
23+
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
24+
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
25+
parser.add_argument("--debug", type=int, help="Debug level - 1 dumps bad detection info, 2 writes out images.", default=0)
26+
parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
27+
parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
28+
parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
29+
args = parser.parse_args()
30+
31+
rec_model = load_recognition_model()
32+
rec_processor = load_recognition_processor()
33+
34+
split = "train"
35+
if args.max:
36+
split = f"train[:{args.max}]"
37+
38+
dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)
39+
40+
if args.langs:
41+
langs = args.langs.split(",")
42+
dataset = dataset.filter(lambda x: x["language"] in langs)
43+
44+
images = list(dataset["image"])
45+
images = [i.convert("RGB") for i in images]
46+
bboxes = list(dataset["bboxes"])
47+
line_text = list(dataset["text"])
48+
languages = list(dataset["language"])
49+
50+
print(f"Loaded {len(images)} images. Running OCR...")
51+
52+
lang_list = []
53+
for l in languages:
54+
if not isinstance(l, list):
55+
lang_list.append([l])
56+
else:
57+
lang_list.append(l)
58+
59+
start = time.time()
60+
predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
61+
surya_time = time.time() - start
62+
63+
surya_scores = defaultdict(list)
64+
img_surya_scores = []
65+
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
66+
image_score = overlap_score(pred["text_lines"], ref_text)
67+
img_surya_scores.append(image_score)
68+
for l in lang:
69+
surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)
70+
71+
flat_surya_scores = [s for l in surya_scores for s in surya_scores[l]]
72+
benchmark_stats = {
73+
"surya": {
74+
"avg_score": sum(flat_surya_scores) / len(flat_surya_scores),
75+
"lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()},
76+
"time_per_img": surya_time / len(images)
77+
}
78+
}
79+
80+
result_path = os.path.join(args.results_dir, "rec_bench")
81+
os.makedirs(result_path, exist_ok=True)
82+
83+
with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
84+
json.dump(surya_scores, f)
85+
86+
if args.tesseract:
87+
tess_valid = []
88+
tess_langs = []
89+
for idx, lang in enumerate(lang_list):
90+
# Tesseract does not support all languages
91+
tess_lang = surya_lang_to_tesseract(lang[0])
92+
if tess_lang is None:
93+
continue
94+
95+
tess_valid.append(idx)
96+
tess_langs.append(tess_lang)
97+
98+
tess_imgs = [images[i] for i in tess_valid]
99+
tess_bboxes = [bboxes[i] for i in tess_valid]
100+
tess_reference = [line_text[i] for i in tess_valid]
101+
start = time.time()
102+
tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs, cpus=args.tess_cpus)
103+
tesseract_time = time.time() - start
104+
105+
tess_scores = defaultdict(list)
106+
for idx, (pred, ref_text, lang) in enumerate(zip(tess_predictions, tess_reference, tess_langs)):
107+
image_score = overlap_score(pred, ref_text)
108+
tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)
109+
110+
flat_tess_scores = [s for l in tess_scores for s in tess_scores[l]]
111+
benchmark_stats["tesseract"] = {
112+
"avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
113+
"lang_scores": {l: sum(scores) / len(scores) for l, scores in tess_scores.items()},
114+
"time_per_img": tesseract_time / len(tess_imgs)
115+
}
116+
117+
with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
118+
json.dump(tess_scores, f)
119+
120+
with open(os.path.join(result_path, "results.json"), "w+") as f:
121+
json.dump(benchmark_stats, f)
122+
123+
key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
124+
table_headers = ["Model", "Time per page (s)", "Avg Score"] + KEY_LANGUAGES
125+
table_data = [
126+
["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages],
127+
]
128+
if args.tesseract:
129+
table_data.append(
130+
["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
131+
)
132+
133+
print(tabulate(table_data, headers=table_headers, tablefmt="github"))
134+
print("Only a few major languages are displayed. See the result path for additional languages.")
135+
136+
if args.debug >= 1:
137+
bad_detections = []
138+
for idx, (score, lang) in enumerate(zip(flat_surya_scores, lang_list)):
139+
if score < .8:
140+
bad_detections.append((idx, lang, score))
141+
print(f"Found {len(bad_detections)} bad detections. Writing to file...")
142+
with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
143+
json.dump(bad_detections, f)
144+
145+
if args.debug == 2:
146+
for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
147+
pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
148+
ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
149+
pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
150+
pred_image.save(os.path.join(result_path, pred_image_name))
151+
ref_image = draw_text_on_image(bbox, ref_text, image.size)
152+
ref_image.save(os.path.join(result_path, ref_image_name))
153+
image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))
154+
155+
print(f"Wrote results to {result_path}")
156+
157+
158+
if __name__ == "__main__":
159+
main()

benchmark/scoring.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import math
2+
from typing import List
3+
4+
from rapidfuzz import fuzz
5+
6+
7+
def overlap_score(pred_lines: List[str], reference_lines: List[str]):
8+
line_scores = []
9+
line_weights = []
10+
for i, pred_line in enumerate(pred_lines):
11+
max_score = 0
12+
line_weight = 1
13+
for j, ref_line in enumerate(reference_lines):
14+
score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
15+
if score > max_score:
16+
max_score = score
17+
line_weight = math.sqrt(len(ref_line))
18+
line_scores.append(max_score)
19+
line_weights.append(line_weight)
20+
line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))]
21+
22+
return sum(line_scores) / sum(line_weights)

benchmark/tesseract_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from surya.benchmark.tesseract import tesseract_bboxes
55
from surya.postprocessing.heatmap import draw_bboxes_on_image
66

7-
from surya.model.processing import open_pdf, get_page_images
7+
from surya.input.processing import open_pdf, get_page_images
88
from surya.settings import settings
99

1010

detect_text.py

+7-59
Original file line numberDiff line numberDiff line change
@@ -3,71 +3,19 @@
33
import json
44
from collections import defaultdict
55

6-
from PIL import Image
7-
8-
from surya.model.segformer import load_model, load_processor
9-
from surya.model.processing import open_pdf, get_page_images
10-
from surya.detection import batch_inference
6+
from surya.input.load import load_from_folder, load_from_file
7+
from surya.model.detection.segformer import load_model, load_processor
8+
from surya.detection import batch_detection
119
from surya.postprocessing.affinity import draw_lines_on_image
12-
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
10+
from surya.postprocessing.heatmap import draw_polys_on_image
1311
from surya.settings import settings
1412
import os
15-
import filetype
16-
17-
18-
def get_name_from_path(path):
19-
return os.path.basename(path).split(".")[0]
20-
21-
22-
def load_pdf(pdf_path, max_pages=None):
23-
doc = open_pdf(pdf_path)
24-
page_count = len(doc)
25-
if max_pages:
26-
page_count = min(max_pages, page_count)
27-
28-
page_indices = list(range(page_count))
29-
30-
images = get_page_images(doc, page_indices)
31-
doc.close()
32-
names = [get_name_from_path(pdf_path) for _ in page_indices]
33-
return images, names
34-
35-
36-
def load_image(image_path):
37-
image = Image.open(image_path).convert("RGB")
38-
name = get_name_from_path(image_path)
39-
return [image], [name]
40-
41-
42-
def load_from_file(input_path, max_pages=None):
43-
input_type = filetype.guess(input_path)
44-
if input_type.extension == "pdf":
45-
return load_pdf(input_path, max_pages)
46-
else:
47-
return load_image(input_path)
48-
49-
50-
def load_from_folder(folder_path, max_pages=None):
51-
image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path)]
52-
image_paths = [ip for ip in image_paths if not os.path.isdir(ip) and not ip.startswith(".")]
53-
54-
images = []
55-
names = []
56-
for path in image_paths:
57-
if filetype.guess(path).extension == "pdf":
58-
image, name = load_pdf(path, max_pages)
59-
images.extend(image)
60-
names.extend(name)
61-
else:
62-
image, name = load_image(path)
63-
images.extend(image)
64-
names.extend(name)
65-
return images, names
13+
from tqdm import tqdm
6614

6715

6816
def main():
6917
parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
70-
parser.add_argument("input_path", type=str, help="Path to pdf or image file to detect bboxes in.")
18+
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
7119
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
7220
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
7321
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
@@ -84,7 +32,7 @@ def main():
8432
images, names = load_from_file(args.input_path, args.max)
8533
folder_name = os.path.basename(args.input_path).split(".")[0]
8634

87-
predictions = batch_inference(images, model, processor)
35+
predictions = batch_detection(images, model, processor)
8836
result_path = os.path.join(args.results_dir, folder_name)
8937
os.makedirs(result_path, exist_ok=True)
9038

0 commit comments

Comments
 (0)