Commit 2cf6a77

Merge pull request #307 from tarun-menta/textract
Add Textract OCR Benchmark
2 parents 40302e8 + d2ead6a commit 2cf6a77

2 files changed (+63 / -5 lines)

benchmark/recognition.py (+34 / -5)
@@ -10,6 +10,7 @@
 from surya.settings import settings
 from surya.recognition.languages import CODE_TO_LANGUAGE
 from benchmark.utils.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
+from benchmark.utils.textract import textract_ocr_parallel
 import os
 import datasets
 import json
@@ -22,22 +23,24 @@
 @click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
 @click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None)
 @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
-@click.option("--tesseract", is_flag=True, help="Run tesseract instead of surya.", default=False)
+@click.option("--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False)
+@click.option("--textract", is_flag=True, help="Run benchmarks on textract.", default=False)
 @click.option("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
 @click.option("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
+@click.option("--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28)
 @click.option("--specify_language", is_flag=True, help="Pass language codes into the model.", default=False)
-def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: str, tess_cpus: int, specify_language: bool):
+def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, textract: bool, langs: str, tess_cpus: int, textract_cpus:int, specify_language: bool):
     rec_predictor = RecognitionPredictor()
 
     split = "train"
-    if max_rows:
-        split = f"train[:{max_rows}]"
-
     dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)
 
     if langs:
         langs = langs.split(",")
         dataset = dataset.filter(lambda x: x["language"] in langs, num_proc=4)
+
+    if max_rows and max_rows<len(dataset):
+        dataset = dataset.shuffle().select(range(max_rows))
 
     images = list(dataset["image"])
     images = convert_if_not_rgb(images)
@@ -121,6 +124,28 @@ def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: s
         with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
             json.dump(tess_scores, f)
 
+    if textract:
+        start = time.time()
+        textract_predictions = textract_ocr_parallel(images, cpus=textract_cpus)
+        textract_time = time.time()-start
+
+        textract_scores = defaultdict(list)
+        for idx, (pred, ref_text, lang) in enumerate(zip(textract_predictions, line_text, lang_list)):
+            image_score = overlap_score(pred, ref_text)
+            for l in lang:
+                textract_scores[CODE_TO_LANGUAGE[l]].append(image_score)
+
+        flat_textract_scores = [s for l in textract_scores for s in textract_scores[l]]
+        benchmark_stats["textract"] = {
+            "avg_score": sum(flat_textract_scores) / len(flat_textract_scores),
+            "lang_scores": {l: sum(scores) / len(scores) for l, scores in textract_scores.items()},
+            "time_per_img": textract_time / len(images)
+        }
+        print(len(flat_textract_scores))
+
+        with open(os.path.join(result_path, "textract_scores.json"), "w+") as f:
+            json.dump(textract_scores, f)
+
     with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
         json.dump(benchmark_stats, f)
 
@@ -133,6 +158,10 @@ def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: s
         table_data.append(
             ["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
         )
+    if textract:
+        table_data.append(
+            ["textract", benchmark_stats["textract"]["time_per_img"], benchmark_stats["textract"]["avg_score"]] + [benchmark_stats["textract"]["lang_scores"][l] for l in key_languages],
+        )
 
     print(tabulate(table_data, headers=table_headers, tablefmt="github"))
     print("Only a few major languages are displayed. See the result path for additional languages.")

benchmark/utils/textract.py (+29, new file)
@@ -0,0 +1,29 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+import traceback
+
+from surya.input.processing import slice_bboxes_from_image
+from surya.recognition import RecognitionPredictor
+
+from textractor import Textractor
+
+def textract_ocr(extractor:Textractor, img):
+    try:
+        document = extractor.detect_document_text(file_source=img)
+        return [line.text for line in document.lines]
+    except:
+        traceback.print_exc()
+        return [None]
+
+def textract_ocr_parallel(imgs, cpus=None):
+    extractor = Textractor(profile_name='default')
+    parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size())
+    if not cpus:
+        cpus = os.cpu_count()
+    parallel_cores = min(parallel_cores, cpus)
+
+    with ThreadPoolExecutor(max_workers=parallel_cores) as executor:
+        textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR")
+        textract_text = list(textract_text)
+    return textract_text
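
As a usage sketch, the new helper can also be called directly on a list of PIL images. This is an illustration rather than part of the commit; it assumes AWS credentials exist for the "default" profile and that the surya recognition model can be loaded (the helper sizes its thread pool from RecognitionPredictor). The image path is hypothetical.

# Hypothetical standalone use of the Textract helper.
from PIL import Image

from benchmark.utils.textract import textract_ocr_parallel

imgs = [Image.open("sample_page.png").convert("RGB")]  # hypothetical local test image
line_texts = textract_ocr_parallel(imgs, cpus=4)  # one list of detected line strings per image
print(line_texts[0][:5])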
