10
10
from surya .settings import settings
11
11
from surya .recognition .languages import CODE_TO_LANGUAGE
12
12
from benchmark .utils .tesseract import tesseract_ocr_parallel , surya_lang_to_tesseract , TESS_CODE_TO_LANGUAGE
13
+ from benchmark .utils .textract import textract_ocr_parallel
13
14
import os
14
15
import datasets
15
16
import json
22
23
@click .option ("--results_dir" , type = str , help = "Path to JSON file with OCR results." , default = os .path .join (settings .RESULT_DIR , "benchmark" ))
23
24
@click .option ("--max_rows" , type = int , help = "Maximum number of pdf pages to OCR." , default = None )
24
25
@click .option ("--debug" , is_flag = True , help = "Enable debug mode." , default = False )
25
- @click .option ("--tesseract" , is_flag = True , help = "Run tesseract instead of surya." , default = False )
26
+ @click .option ("--tesseract" , is_flag = True , help = "Run benchmarks on tesseract." , default = False )
27
+ @click .option ("--textract" , is_flag = True , help = "Run benchmarks on textract." , default = False )
26
28
@click .option ("--langs" , type = str , help = "Specify certain languages to benchmark." , default = None )
27
29
@click .option ("--tess_cpus" , type = int , help = "Number of CPUs to use for tesseract." , default = 28 )
30
+ @click .option ("--textract_cpus" , type = int , help = "Number of CPUs to use for textract." , default = 28 )
28
31
@click .option ("--specify_language" , is_flag = True , help = "Pass language codes into the model." , default = False )
29
- def main (results_dir : str , max_rows : int , debug : bool , tesseract : bool , langs : str , tess_cpus : int , specify_language : bool ):
32
+ def main (results_dir : str , max_rows : int , debug : bool , tesseract : bool , textract : bool , langs : str , tess_cpus : int , textract_cpus : int , specify_language : bool ):
30
33
rec_predictor = RecognitionPredictor ()
31
34
32
35
split = "train"
33
- if max_rows :
34
- split = f"train[:{ max_rows } ]"
35
-
36
36
dataset = datasets .load_dataset (settings .RECOGNITION_BENCH_DATASET_NAME , split = split )
37
37
38
38
if langs :
39
39
langs = langs .split ("," )
40
40
dataset = dataset .filter (lambda x : x ["language" ] in langs , num_proc = 4 )
41
+
42
+ if max_rows and max_rows < len (dataset ):
43
+ dataset = dataset .shuffle ().select (range (max_rows ))
41
44
42
45
images = list (dataset ["image" ])
43
46
images = convert_if_not_rgb (images )
@@ -121,6 +124,28 @@ def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: s
121
124
with open (os .path .join (result_path , "tesseract_scores.json" ), "w+" ) as f :
122
125
json .dump (tess_scores , f )
123
126
127
+ if textract :
128
+ start = time .time ()
129
+ textract_predictions = textract_ocr_parallel (images , cpus = textract_cpus )
130
+ textract_time = time .time ()- start
131
+
132
+ textract_scores = defaultdict (list )
133
+ for idx , (pred , ref_text , lang ) in enumerate (zip (textract_predictions , line_text , lang_list )):
134
+ image_score = overlap_score (pred , ref_text )
135
+ for l in lang :
136
+ textract_scores [CODE_TO_LANGUAGE [l ]].append (image_score )
137
+
138
+ flat_textract_scores = [s for l in textract_scores for s in textract_scores [l ]]
139
+ benchmark_stats ["textract" ] = {
140
+ "avg_score" : sum (flat_textract_scores ) / len (flat_textract_scores ),
141
+ "lang_scores" : {l : sum (scores ) / len (scores ) for l , scores in textract_scores .items ()},
142
+ "time_per_img" : textract_time / len (images )
143
+ }
144
+ print (len (flat_textract_scores ))
145
+
146
+ with open (os .path .join (result_path , "textract_scores.json" ), "w+" ) as f :
147
+ json .dump (textract_scores , f )
148
+
124
149
with open (os .path .join (result_path , "results.json" ), "w+" , encoding = "utf-8" ) as f :
125
150
json .dump (benchmark_stats , f )
126
151
@@ -133,6 +158,10 @@ def main(results_dir: str, max_rows: int, debug: bool, tesseract: bool, langs: s
133
158
table_data .append (
134
159
["tesseract" , benchmark_stats ["tesseract" ]["time_per_img" ], benchmark_stats ["tesseract" ]["avg_score" ]] + [benchmark_stats ["tesseract" ]["lang_scores" ].get (l , 0 ) for l in key_languages ]
135
160
)
161
+ if textract :
162
+ table_data .append (
163
+ ["textract" , benchmark_stats ["textract" ]["time_per_img" ], benchmark_stats ["textract" ]["avg_score" ]] + [benchmark_stats ["textract" ]["lang_scores" ][l ] for l in key_languages ],
164
+ )
136
165
137
166
print (tabulate (table_data , headers = table_headers , tablefmt = "github" ))
138
167
print ("Only a few major languages are displayed. See the result path for additional languages." )
0 commit comments