Commit e0785d1

Add in tesseract to benchmark
1 parent 6371081 commit e0785d1

10 files changed, +273 -40 lines changed


.github/workflows/tests.yml (+6 −2)

```diff
@@ -24,10 +24,14 @@ jobs:
           poetry install
           poetry remove torch
           poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
-      - name: Run benchmark test
+      - name: Run detection benchmark test
         run: |
           poetry run python benchmark/detection.py --max 2
-          poetry run python scripts/verify_benchmark_scores.py results/benchmark/doclaynet_bench/results.json
+          poetry run python scripts/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
+      - name: Run recognition benchmark test
+        run: |
+          poetry run python benchmark/recognition.py --max 2
+          poetry run python scripts/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition
```

README.md (+26 −11)

````diff
@@ -1,9 +1,9 @@
 # Surya
 
-Surya is a multilingual document OCR toolkit. It can do:
+Surya is for multilingual document OCR. It can do:
 
-- Accurate line-level text detection in any language
-- Text recognition in 90+ languages
+- Accurate OCR in 90+ languages
+- Line-level text detection in any language
 - Table and chart detection (coming soon)
 
 It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
@@ -39,23 +39,23 @@ Install with:
 pip install surya-ocr
 ```
 
-Model weights will automatically download the first time you run surya.
+Model weights will automatically download the first time you run surya. Note that this does not work with the latest version of transformers `4.37+` [yet](https://github.com/huggingface/transformers/issues/28846#issuecomment-1926109135), so you will need to keep `4.36.2`, which is installed with surya.
 
 # Usage
 
 - Inspect the settings in `surya/settings.py`. You can override any settings with environment variables.
-- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. Note that the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
+- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. For text detection, the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
 
 ## OCR (text recognition)
 
-You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
+You can detect text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
 
 ```
 surya_ocr DATA_PATH --images --langs hi,en
 ```
 
 - `DATA_PATH` can be an image, pdf, or folder of images/pdfs
-- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`.
+- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages (I don't recommend using more than `4`). Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`.
 - `--lang_file` if you want to use a different language for different PDFs/images, you can specify languages here. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`.
 - `--images` will save images of the pages and detected text lines (optional)
 - `--results_dir` specifies the directory to save results to instead of the default
@@ -158,15 +158,17 @@ If you want to develop surya, you can install it manually:
 
 - `git clone https://github.com/VikParuchuri/surya.git`
 - `cd surya`
-- `poetry install` # Installs main and dev dependencies
+- `poetry install` - installs main and dev dependencies
+- `poetry shell` - activates the virtual environment
 
 # Limitations
 
-- Math will not be detected well with the main model. Use `DETECTOR_MODEL_CHECKPOINT=vikp/line_detector_math` for better results.
 - This is specialized for document OCR. It will likely not work on photos or other images.
 - It is for printed text, not handwriting.
 - The model has trained itself to ignore advertisements.
 - You can find language support for OCR in `surya/languages.py`. Text detection should work with any language.
+- Math will not be detected well with the main detector model. Use `DETECTOR_MODEL_CHECKPOINT=vikp/line_detector_math` for better results.
+
 
 # Benchmarks
 
@@ -207,7 +209,7 @@ Then we calculate precision and recall for the whole dataset.
 You can benchmark the performance of surya on your machine.
 
 - Follow the manual install instructions above.
-- `poetry install --group dev` # Installs dev dependencies
+- `poetry install --group dev` - installs dev dependencies
 
 **Text line detection**
 
@@ -222,10 +224,23 @@ python benchmark/detection.py --max 256
 - `--pdf_path` will let you specify a pdf to benchmark instead of the default data
 - `--results_dir` will let you specify a directory to save results to instead of the default one
 
+**Text recognition**
+
+This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl.
+
+```
+python benchmark/recognition.py --max 256
+```
+
+- `--max` controls how many images to process for the benchmark
+- `--debug` will render images with detected text
+- `--results_dir` will let you specify a directory to save results to instead of the default one
+- `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.
+
 
 # Training
 
-The text detection was trained on 4x A6000s for about 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
+Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
 
 Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).
````
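As a companion to the CLI usage documented above, here is a rough sketch of driving the same recognition pipeline programmatically, pieced together from the imports and the `run_recognition(...)` call in `benchmark/recognition.py` further down. The PIL-based image loading, the `[x1, y1, x2, y2]` line-box format, and the `page.png` path are illustrative assumptions, not something this commit spells out.

```python
# Hypothetical sketch modeled on benchmark/recognition.py in this commit.
# Image loading and the bbox format are assumptions for illustration.
from PIL import Image

from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.ocr import run_recognition

rec_model = load_recognition_model()
rec_processor = load_recognition_processor()

images = [Image.open("page.png")]        # one PIL image per page (assumed)
lang_list = [["en", "hi"]]               # languages per image, as with --langs
bboxes = [[[50, 100, 700, 130]]]         # assumed [x1, y1, x2, y2] line boxes per image

predictions = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
print(predictions[0]["text_lines"])      # list of recognized strings, per the benchmark's scoring loop
```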

benchmark/detection.py (+1 −1)

```diff
@@ -42,7 +42,7 @@ def main():
         image_sizes = [img.size for img in images]
         correct_boxes = get_pdf_lines(args.pdf_path, image_sizes)
     else:
-        pathname = "doclaynet_bench"
+        pathname = "det_bench"
         # These have already been shuffled randomly, so sampling from the start is fine
         dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
         images = list(dataset["image"])
```

benchmark/recognition.py (+66 −12)

```diff
@@ -4,21 +4,26 @@
 from benchmark.scoring import overlap_score
 from surya.model.recognition.model import load_model as load_recognition_model
 from surya.model.recognition.processor import load_processor as load_recognition_processor
-from surya.ocr import run_ocr, run_recognition
+from surya.ocr import run_recognition
 from surya.postprocessing.text import draw_text_on_image
 from surya.settings import settings
-from surya.languages import CODE_TO_LANGUAGE, is_arabic
-import arabic_reshaper
+from surya.languages import CODE_TO_LANGUAGE
+from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
 import os
 import datasets
 import json
+import time
+from tabulate import tabulate
+
+KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]
 
 
 def main():
     parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
     parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
     parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
     parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
+    parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
     args = parser.parse_args()
 
     rec_model = load_recognition_model()
@@ -44,25 +49,74 @@ def main():
         else:
             lang_list.append(l)
 
+    start = time.time()
     predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
+    surya_time = time.time() - start
 
-    image_scores = defaultdict(list)
+    surya_scores = defaultdict(list)
     for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
-        if any(is_arabic(l) for l in lang):
-            ref_text = [arabic_reshaper.reshape(t) for t in ref_text]
-            pred["text_lines"] = [arabic_reshaper.reshape(t) for t in pred["text_lines"]]
         image_score = overlap_score(pred["text_lines"], ref_text)
         for l in lang:
-            image_scores[CODE_TO_LANGUAGE[l]].append(image_score)
-
-    image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()}
-    print(image_avgs)
+            surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)
+
+    flat_surya_scores = [s for l in surya_scores for s in surya_scores[l]]
+    benchmark_stats = {
+        "surya": {
+            "avg_score": sum(flat_surya_scores) / len(flat_surya_scores),
+            "lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()},
+            "time_per_img": surya_time / len(images)
+        }
+    }
+
+    if args.tesseract:
+        tess_valid = []
+        tess_langs = []
+        for idx, lang in enumerate(lang_list):
+            # Tesseract does not support all languages
+            tess_lang = surya_lang_to_tesseract(lang[0])
+            if tess_lang is None:
+                continue
+
+            tess_valid.append(idx)
+            tess_langs.append(tess_lang)
+
+        tess_imgs = [images[i] for i in tess_valid]
+        tess_bboxes = [bboxes[i] for i in tess_valid]
+        tess_reference = [line_text[i] for i in tess_valid]
+        start = time.time()
+        tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs)
+        tesseract_time = time.time() - start
+
+        tess_scores = defaultdict(list)
+        for idx, (pred, ref_text, lang) in enumerate(zip(tess_predictions, tess_reference, tess_langs)):
+            image_score = overlap_score(pred, ref_text)
+            tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)
+
+        flat_tess_scores = [s for l in tess_scores for s in tess_scores[l]]
+        benchmark_stats["tesseract"] = {
+            "avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
+            "lang_scores": {l: sum(scores) / len(scores) for l, scores in tess_scores.items()},
+            "time_per_img": tesseract_time / len(tess_imgs)
+        }
 
     result_path = os.path.join(args.results_dir, "rec_bench")
     os.makedirs(result_path, exist_ok=True)
 
     with open(os.path.join(result_path, "results.json"), "w+") as f:
-        json.dump(image_scores, f)
+        json.dump(benchmark_stats, f)
+
+    key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
+    table_headers = ["Model", "Time per page (s)", "Avg Score"] + KEY_LANGUAGES
+    table_data = [
+        ["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages],
+    ]
+    if args.tesseract:
+        table_data.append(
+            ["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
+        )
+
+    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
+    print("Only a few major languages are displayed. See the result path for additional languages.")
 
     if args.debug:
         for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
```
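As context for the CI verify step above and for `verify_rec` further down, the `benchmark_stats` dict this script now writes to `rec_bench/results.json` has roughly the following shape. This is a sketch with invented numbers; only the keys follow from the code in this diff, and the `tesseract` block only appears when `--tesseract` is passed.

```python
# Illustrative only: values are invented, keys mirror benchmark_stats above.
example_results = {
    "surya": {
        "avg_score": 0.97,                                 # mean overlap_score across all images
        "lang_scores": {"English": 0.98, "Hindi": 0.95},   # per-language means
        "time_per_img": 0.42,                              # seconds per page
    },
    "tesseract": {                                         # only written when --tesseract is set
        "avg_score": 0.88,
        "lang_scores": {"English": 0.91, "Hindi": 0.84},
        "time_per_img": 0.55,
    },
}
```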

benchmark/scoring.py (+3 −1)

```diff
@@ -1,8 +1,10 @@
 import math
+from typing import List
+
 from rapidfuzz import fuzz
 
 
-def overlap_score(pred_lines, reference_lines):
+def overlap_score(pred_lines: List[str], reference_lines: List[str]):
     line_scores = []
     line_weights = []
     for i, pred_line in enumerate(pred_lines):
```
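The diff only touches the signature of `overlap_score`; its body is unchanged and not shown here. For readers unfamiliar with it, below is a minimal sketch of what a rapidfuzz-based, length-weighted line overlap metric of this shape could look like. This is an illustration under stated assumptions (pairwise matching by index, weighting by reference length), not the actual implementation.

```python
from typing import List

from rapidfuzz import fuzz


def overlap_score_sketch(pred_lines: List[str], reference_lines: List[str]) -> float:
    # Fuzzy-match each predicted line against the reference line at the same index,
    # then weight by reference length so longer lines dominate the average.
    line_scores = []
    line_weights = []
    for pred_line, ref_line in zip(pred_lines, reference_lines):
        score = fuzz.ratio(pred_line, ref_line) / 100  # rapidfuzz returns 0-100
        weight = max(len(ref_line), 1)
        line_scores.append(score * weight)
        line_weights.append(weight)
    if not line_weights:
        return 0.0
    return sum(line_scores) / sum(line_weights)
```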

scripts/verify_benchmark_scores.py (+21 −7)

```diff
@@ -2,19 +2,33 @@
 import argparse
 
 
-def verify_scores(file_path):
-    with open(file_path, 'r') as file:
-        data = json.load(file)
-
+def verify_det(data):
     scores = data["metrics"]["surya"]
-
     if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
-        print(scores)
         raise ValueError("Scores do not meet the required threshold")
 
 
+def verify_rec(data):
+    scores = data["surya"]
+    if scores["avg_score"] <= 0.9:
+        raise ValueError("Scores do not meet the required threshold")
+
+
+def verify_scores(file_path, bench_type):
+    with open(file_path, 'r') as file:
+        data = json.load(file)
+
+    if bench_type == "detection":
+        verify_det(data)
+    elif bench_type == "recognition":
+        verify_rec(data)
+    else:
+        raise ValueError("Invalid benchmark type")
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Verify benchmark scores")
     parser.add_argument("file_path", type=str, help="Path to the json file")
+    parser.add_argument("--bench_type", type=str, help="Type of benchmark to verify", default="detection")
     args = parser.parse_args()
-    verify_scores(args.file_path)
+    verify_scores(args.file_path, args.bench_type)
```
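To make the verifier's expected inputs concrete, here are minimal examples of the two JSON shapes it reads after this change. The values are invented; only the key paths follow from `verify_det` and `verify_rec` above.

```python
# Invented values; only the key paths match what verify_det / verify_rec read.
detection_results = {
    "metrics": {
        "surya": {"precision": 0.95, "recall": 0.93}  # both must exceed 0.9
    }
}

recognition_results = {
    "surya": {"avg_score": 0.97}  # must exceed 0.9; other keys such as lang_scores are ignored
}
```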
