|
| 1 | +import argparse |
| 2 | +from collections import defaultdict |
| 3 | + |
| 4 | +from benchmark.scoring import overlap_score |
| 5 | +from surya.model.recognition.model import load_model as load_recognition_model |
| 6 | +from surya.model.recognition.processor import load_processor as load_recognition_processor |
| 7 | +from surya.ocr import run_recognition |
| 8 | +from surya.postprocessing.text import draw_text_on_image |
| 9 | +from surya.settings import settings |
| 10 | +from surya.languages import CODE_TO_LANGUAGE |
| 11 | +from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE |
| 12 | +import os |
| 13 | +import datasets |
| 14 | +import json |
| 15 | +import time |
| 16 | +from tabulate import tabulate |

# Subset of languages shown in the printed summary table; full per-language
# scores for every language are still written to the result JSON files.
KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]

| 20 | + |
def main():
    """Benchmark surya text recognition (optionally against tesseract).

    Loads the recognition benchmark dataset, runs line-level OCR on each
    image using the provided bboxes, scores predictions against reference
    text with ``overlap_score``, writes per-language score JSON files plus a
    combined ``results.json``, and prints a summary table for a few key
    languages.  Debug levels additionally dump low-scoring samples and
    rendered prediction/reference images.
    """
    parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
    parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
    parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
    parser.add_argument("--debug", type=int, help="Debug level - 1 dumps bad detection info, 2 writes out images.", default=0)
    parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
    parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
    parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
    args = parser.parse_args()

    rec_model = load_recognition_model()
    rec_processor = load_recognition_processor()

    # Limit the number of samples via a datasets split slice when requested.
    split = "train"
    if args.max:
        split = f"train[:{args.max}]"

    dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)

    if args.langs:
        langs = args.langs.split(",")
        dataset = dataset.filter(lambda x: x["language"] in langs)

    images = [img.convert("RGB") for img in dataset["image"]]
    bboxes = list(dataset["bboxes"])
    line_text = list(dataset["text"])
    languages = list(dataset["language"])

    print(f"Loaded {len(images)} images. Running OCR...")

    # Normalize so every entry is a list of language codes (one per image).
    lang_list = [lang if isinstance(lang, list) else [lang] for lang in languages]

    start = time.time()
    predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
    surya_time = time.time() - start

    surya_scores = defaultdict(list)
    img_surya_scores = []  # one score per image, aligned with lang_list
    for pred, ref_text, lang in zip(predictions_by_image, line_text, lang_list):
        image_score = overlap_score(pred["text_lines"], ref_text)
        img_surya_scores.append(image_score)
        # An image may carry several languages; credit the score to each one.
        for l in lang:
            surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)

    flat_surya_scores = [s for scores in surya_scores.values() for s in scores]
    benchmark_stats = {
        "surya": {
            "avg_score": sum(flat_surya_scores) / len(flat_surya_scores),
            "lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()},
            "time_per_img": surya_time / len(images)
        }
    }

    result_path = os.path.join(args.results_dir, "rec_bench")
    os.makedirs(result_path, exist_ok=True)

    with open(os.path.join(result_path, "surya_scores.json"), "w+") as f:
        json.dump(surya_scores, f)

    if args.tesseract:
        tess_valid = []
        tess_langs = []
        for idx, lang in enumerate(lang_list):
            # Tesseract does not support all languages
            tess_lang = surya_lang_to_tesseract(lang[0])
            if tess_lang is None:
                continue

            tess_valid.append(idx)
            tess_langs.append(tess_lang)

        tess_imgs = [images[i] for i in tess_valid]
        tess_bboxes = [bboxes[i] for i in tess_valid]
        tess_reference = [line_text[i] for i in tess_valid]
        start = time.time()
        tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs, cpus=args.tess_cpus)
        tesseract_time = time.time() - start

        tess_scores = defaultdict(list)
        for pred, ref_text, tess_lang in zip(tess_predictions, tess_reference, tess_langs):
            image_score = overlap_score(pred, ref_text)
            tess_scores[TESS_CODE_TO_LANGUAGE[tess_lang]].append(image_score)

        flat_tess_scores = [s for scores in tess_scores.values() for s in scores]
        benchmark_stats["tesseract"] = {
            "avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
            "lang_scores": {l: sum(scores) / len(scores) for l, scores in tess_scores.items()},
            "time_per_img": tesseract_time / len(tess_imgs)
        }

        with open(os.path.join(result_path, "tesseract_scores.json"), "w+") as f:
            json.dump(tess_scores, f)

    with open(os.path.join(result_path, "results.json"), "w+") as f:
        json.dump(benchmark_stats, f)

    key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
    # Bug fix: headers previously used the full KEY_LANGUAGES list while the
    # data rows only contain columns for key_languages (the languages actually
    # present), misaligning the table whenever a key language is missing.
    table_headers = ["Model", "Time per page (s)", "Avg Score"] + key_languages
    table_data = [
        ["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages],
    ]
    if args.tesseract:
        table_data.append(
            ["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
        )

    print(tabulate(table_data, headers=table_headers, tablefmt="github"))
    print("Only a few major languages are displayed. See the result path for additional languages.")

    if args.debug >= 1:
        bad_detections = []
        # Bug fix: use the per-image scores (img_surya_scores), which are
        # aligned with lang_list, instead of the flattened per-language list —
        # otherwise the recorded indices do not correspond to image indices.
        for idx, (score, lang) in enumerate(zip(img_surya_scores, lang_list)):
            if score < .8:
                bad_detections.append((idx, lang, score))
        print(f"Found {len(bad_detections)} bad detections. Writing to file...")
        with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
            json.dump(bad_detections, f)

    if args.debug == 2:
        for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
            pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
            ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
            pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
            pred_image.save(os.path.join(result_path, pred_image_name))
            ref_image = draw_text_on_image(bbox, ref_text, image.size)
            ref_image.save(os.path.join(result_path, ref_image_name))
            image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))

    print(f"Wrote results to {result_path}")
| 156 | + |


# Script entry point: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()