Skip to content

Commit 3323022

Browse files
authored
Merge pull request #17 from VikParuchuri/rec
Early test version of recognition
2 parents 7612486 + 5b91009 commit 3323022

27 files changed

+3018
-923
lines changed

README.md

+72-14
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
Surya is a multilingual document OCR toolkit. It can do:
44

5-
- Accurate line-level text detection
6-
- Text recognition (coming soon)
5+
- Accurate line-level text detection in any language
6+
- Text recognition in 90+ languages
77
- Table and chart detection (coming soon)
88

9-
It works on a range of documents and languages (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
9+
It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
1010

1111
![New York Times Article Example](static/images/excerpt.png)
1212

@@ -46,6 +46,62 @@ Model weights will automatically download the first time you run surya.
4646
- Inspect the settings in `surya/settings.py`. You can override any settings with environment variables.
4747
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. Note that the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
4848

49+
## OCR (text recognition)
50+
51+
You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
52+
53+
```
54+
surya_ocr DATA_PATH --images --lang hi,en
55+
```
56+
57+
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
58+
- `--lang` specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`.
59+
- `--lang_file` if you want to use a different language for different PDFs/images, you can specify languages here. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`.
60+
- `--images` will save images of the pages and detected text lines (optional)
61+
- `--results_dir` specifies the directory to save results to instead of the default
62+
- `--max` specifies the maximum number of pages to process if you don't want to process everything
63+
- `--start_page` specifies the page number to start processing from
64+
65+
The `results.json` file will contain these keys for each page of the input document(s):
66+
67+
- `text_lines` - the detected text in each line
68+
- `polys` - the polygons for each detected text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
69+
- `bboxes` - the axis-aligned rectangles for each detected text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
70+
- `language` - the languages specified for the page
71+
- `name` - the name of the file
72+
- `page_number` - the page number in the file
73+
74+
**Performance tips**
75+
76+
Setting the `RECOGNITION_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `40MB` of VRAM, so very high batch sizes are possible. The default is a batch size `256`, which will use about 10GB of VRAM.
77+
78+
Depending on your CPU core count, `RECOGNITION_BATCH_SIZE` might make a difference there too - the default CPU batch size is `32`.
79+
80+
81+
### From Python
82+
83+
You can also do OCR from code with:
84+
85+
```
86+
from PIL import Image
87+
from surya.ocr import run_ocr
88+
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
89+
from surya.model.recognition.model import load_model as load_rec_model
90+
from surya.model.recognition.processor import load_processor as load_rec_processor
91+
92+
image = Image.open(IMAGE_PATH)
93+
langs = ["en"] # Replace with your languages
94+
95+
det_processor = load_det_processor()
96+
det_model = load_det_model()
97+
98+
rec_model = load_rec_model()
99+
rec_processor = load_rec_processor()
100+
101+
predictions = run_ocr([image], langs, det_model, det_processor, rec_model, rec_processor)
102+
```
103+
104+
49105
## Text line detection
50106

51107
You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected bboxes, and optionally save images of the pages with the bboxes.
@@ -75,26 +131,23 @@ Depending on your CPU core count, `DETECTOR_BATCH_SIZE` might make a difference
75131

76132
You can adjust `DETECTOR_NMS_THRESHOLD` and `DETECTOR_TEXT_THRESHOLD` if you don't get good results. Try lowering them to detect more text, and vice versa.
77133

134+
78135
### From Python
79136

80137
You can also do text detection from code with:
81138

82139
```
83140
from PIL import Image
84-
from surya.detection import batch_inference
141+
from surya.detection import batch_detection
85142
from surya.model.segformer import load_model, load_processor
86143
87144
image = Image.open(IMAGE_PATH)
88145
model, processor = load_model(), load_processor()
89146
90147
# predictions is a list of dicts, one per image
91-
predictions = batch_inference([image], model, processor)
148+
predictions = batch_detection([image], model, processor)
92149
```
93150

94-
## Text recognition
95-
96-
Coming soon.
97-
98151
## Table and chart detection
99152

100153
Coming soon.
@@ -113,10 +166,14 @@ If you want to develop surya, you can install it manually:
113166
- This is specialized for document OCR. It will likely not work on photos or other images.
114167
- It is for printed text, not handwriting.
115168
- The model has trained itself to ignore advertisements.
116-
- This has worked for every language I've tried, but languages with very different character sets may not work well.
169+
- You can find language support for OCR in `surya/languages.py`. Text detection should work with any language.
117170

118171
# Benchmarks
119172

173+
## OCR
174+
175+
Coming soon.
176+
120177
## Text line detection
121178

122179
![Benchmark chart](static/images/benchmark_chart_small.png)
@@ -168,13 +225,13 @@ python benchmark/detection.py --max 256
168225

169226
# Training
170227

171-
This was trained on 4x A6000s for about 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
228+
The text detection was trained on 4x A6000s for about 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
172229

173-
# Commercial usage
230+
Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).
174231

175-
**Text detection**
232+
# Commercial usage
176233

177-
The text detection model was trained from scratch, so it's okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.
234+
The text detection and OCR models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.
178235

179236
If you want to remove the GPL license requirements for inference or use the weights commercially over the revenue limit, please contact me at surya@vikas.sh for dual licensing.
180237

@@ -183,6 +240,7 @@ If you want to remove the GPL license requirements for inference or use the weig
183240
This work would not have been possible without amazing open source AI work:
184241

185242
- [Segformer](https://arxiv.org/pdf/2105.15203.pdf) from NVIDIA
243+
- [Donut](https://github.com/clovaai/donut) from Naver
186244
- [transformers](https://github.com/huggingface/transformers) from huggingface
187245
- [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model
188246

benchmark/detection.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from surya.benchmark.bbox import get_pdf_lines
77
from surya.benchmark.metrics import precision_recall
88
from surya.benchmark.tesseract import tesseract_parallel
9-
from surya.model.segformer import load_model, load_processor
9+
from surya.model.detection.segformer import load_model, load_processor
1010
from surya.input.processing import open_pdf, get_page_images
11-
from surya.detection import batch_inference
11+
from surya.detection import batch_detection
1212
from surya.postprocessing.heatmap import draw_polys_on_image
1313
from surya.postprocessing.util import rescale_bbox
1414
from surya.settings import settings
@@ -44,7 +44,7 @@ def main():
4444
else:
4545
pathname = "doclaynet_bench"
4646
# These have already been shuffled randomly, so sampling from the start is fine
47-
dataset = datasets.load_dataset(settings.BENCH_DATASET_NAME, split=f"train[:{args.max}]")
47+
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
4848
images = list(dataset["image"])
4949
images = [i.convert("RGB") for i in images]
5050
correct_boxes = []
@@ -54,7 +54,7 @@ def main():
5454
correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes])
5555

5656
start = time.time()
57-
predictions = batch_inference(images, model, processor)
57+
predictions = batch_detection(images, model, processor)
5858
surya_time = time.time() - start
5959

6060
start = time.time()

benchmark/recognition.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import argparse
2+
from collections import defaultdict
3+
4+
from benchmark.scoring import overlap_score
5+
from surya.model.recognition.model import load_model as load_recognition_model
6+
from surya.model.recognition.processor import load_processor as load_recognition_processor
7+
from surya.ocr import run_ocr, run_recognition
8+
from surya.postprocessing.text import draw_text_on_image
9+
from surya.settings import settings
10+
from surya.languages import CODE_TO_LANGUAGE, is_arabic
11+
import arabic_reshaper
12+
import os
13+
import datasets
14+
import json
15+
16+
17+
def main():
18+
parser = argparse.ArgumentParser(description="Detect bboxes in a PDF.")
19+
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
20+
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
21+
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
22+
args = parser.parse_args()
23+
24+
rec_model = load_recognition_model()
25+
rec_processor = load_recognition_processor()
26+
27+
split = "train"
28+
if args.max:
29+
split = f"train[:{args.max}]"
30+
31+
dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)
32+
images = list(dataset["image"])
33+
images = [i.convert("RGB") for i in images]
34+
bboxes = list(dataset["bboxes"])
35+
line_text = list(dataset["text"])
36+
languages = list(dataset["language"])
37+
38+
print(f"Loaded {len(images)} images. Running OCR...")
39+
40+
lang_list = []
41+
for l in languages:
42+
if not isinstance(l, list):
43+
lang_list.append([l])
44+
else:
45+
lang_list.append(l)
46+
47+
predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
48+
49+
image_scores = defaultdict(list)
50+
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
51+
if any(is_arabic(l) for l in lang):
52+
ref_text = [arabic_reshaper.reshape(t) for t in ref_text]
53+
pred["text_lines"] = [arabic_reshaper.reshape(t) for t in pred["text_lines"]]
54+
image_score = overlap_score(pred["text_lines"], ref_text)
55+
for l in lang:
56+
image_scores[CODE_TO_LANGUAGE[l]].append(image_score)
57+
58+
image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()}
59+
print(image_avgs)
60+
61+
result_path = os.path.join(args.results_dir, "rec_bench")
62+
os.makedirs(result_path, exist_ok=True)
63+
64+
with open(os.path.join(result_path, "results.json"), "w+") as f:
65+
json.dump(image_scores, f)
66+
67+
if args.debug:
68+
for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
69+
pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
70+
ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
71+
pred_image = draw_text_on_image(bbox, pred["text_lines"], image.size)
72+
pred_image.save(os.path.join(result_path, pred_image_name))
73+
ref_image = draw_text_on_image(bbox, ref_text, image.size)
74+
ref_image.save(os.path.join(result_path, ref_image_name))
75+
image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))
76+
77+
print(f"Wrote results to {result_path}")
78+
79+
80+
if __name__ == "__main__":
81+
main()

benchmark/scoring.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import math
2+
from rapidfuzz import fuzz
3+
4+
5+
def overlap_score(pred_lines, reference_lines):
6+
line_scores = []
7+
line_weights = []
8+
for i, pred_line in enumerate(pred_lines):
9+
max_score = 0
10+
line_weight = 1
11+
for j, ref_line in enumerate(reference_lines):
12+
score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100
13+
if score > max_score:
14+
max_score = score
15+
line_weight = math.sqrt(len(ref_line))
16+
line_scores.append(max_score)
17+
line_weights.append(line_weight)
18+
line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))]
19+
20+
return sum(line_scores) / sum(line_weights)

demo_app.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import gradio as gr
2-
from surya.detection import batch_inference
3-
from surya.model.segformer import load_model, load_processor
2+
from surya.detection import batch_detection
3+
from surya.model.detection.segformer import load_model, load_processor
44
from surya.postprocessing.heatmap import draw_polys_on_image
55

66
model, processor = load_model(), load_processor()
@@ -18,7 +18,7 @@
1818
""".strip()
1919

2020
def text_detection(img):
21-
preds = batch_inference([img], model, processor)[0]
21+
preds = batch_detection([img], model, processor)[0]
2222
img = draw_polys_on_image(preds["polygons"], img)
2323
return img, preds
2424

detect_text.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
from collections import defaultdict
55

66
from surya.input.load import load_from_folder, load_from_file
7-
from surya.model.segformer import load_model, load_processor
8-
from surya.detection import batch_inference
7+
from surya.model.detection.segformer import load_model, load_processor
8+
from surya.detection import batch_detection
99
from surya.postprocessing.affinity import draw_lines_on_image
1010
from surya.postprocessing.heatmap import draw_polys_on_image
1111
from surya.settings import settings
1212
import os
13+
from tqdm import tqdm
1314

1415

1516
def main():
1617
parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).")
17-
parser.add_argument("input_path", type=str, help="Path to pdf or image file to detect bboxes in.")
18+
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.")
1819
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya"))
1920
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
2021
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
@@ -31,7 +32,7 @@ def main():
3132
images, names = load_from_file(args.input_path, args.max)
3233
folder_name = os.path.basename(args.input_path).split(".")[0]
3334

34-
predictions = batch_inference(images, model, processor)
35+
predictions = batch_detection(images, model, processor)
3536
result_path = os.path.join(args.results_dir, folder_name)
3637
os.makedirs(result_path, exist_ok=True)
3738

0 commit comments

Comments
 (0)