Skip to content

Commit 20b7b62

Browse files
committed
Add bad OCR detection to app
1 parent 2525ee0 commit 20b7b62

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

ocr_app.py

+55-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import tempfile
23
from typing import List
34

45
import pypdfium2
@@ -15,6 +16,7 @@
1516
from surya.model.recognition.processor import load_processor as load_rec_processor
1617
from surya.model.table_rec.model import load_model as load_table_model
1718
from surya.model.table_rec.processor import load_processor as load_table_processor
19+
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
1820
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
1921
from surya.ocr import run_ocr
2022
from surya.postprocessing.text import draw_text_on_image
@@ -24,7 +26,9 @@
2426
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
2527
from surya.settings import settings
2628
from surya.tables import batch_table_recognition
27-
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
29+
from surya.postprocessing.util import rescale_bbox
30+
from pdftext.extraction import plain_text_output
31+
from surya.ocr_error import batch_ocr_error_detection
2832

2933

3034
@st.cache_resource()
@@ -46,6 +50,39 @@ def load_layout_cached():
4650
def load_table_cached():
4751
return load_table_model(), load_table_processor()
4852

53+
@st.cache_resource()
def load_ocr_error_cached():
    """Load the OCR error detection model together with its tokenizer.

    Returns:
        A ``(model, processor)`` tuple, where the processor is the tokenizer
        returned by ``load_ocr_error_processor``.
    """
    model = load_ocr_error_model()
    processor = load_ocr_error_processor()
    return model, processor
56+
57+
58+
def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
    """Classify the embedded text of a PDF as good or garbled.

    Text is extracted from a window of pages centred on the middle of the
    document, cut into evenly spaced samples, and each sample is labelled by
    the OCR error detection model.

    Args:
        pdf_file: File-like object exposing ``getvalue()`` that returns the
            raw PDF bytes.
        page_count: Total number of pages in the PDF. ``None`` or ``0`` is
            treated as "nothing to analyse".
        sample_len: Maximum number of characters in each sample.
        max_samples: Upper bound on the number of samples sent to the model.
        max_pages: Number of pages taken on each side of the middle page.

    Returns:
        A ``(message, labels)`` tuple. ``message`` is a human-readable verdict
        and ``labels`` is the list of per-sample labels from the model, or
        ``["no text"]`` when there was nothing to classify.
    """
    no_text = ("This PDF has no text or very little text", ["no text"])

    # Without a usable page count there is nothing to sample; bail out before
    # touching the file so a missing count cannot crash the arithmetic below.
    if not page_count:
        return no_text

    with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
        f.write(pdf_file.getvalue())
        # The extractor is given the file *name*, so make sure the buffered
        # bytes are actually on disk before it reads them.
        f.flush()

        # Sample the text from the middle of the PDF
        page_middle = page_count // 2
        first_page = max(page_middle - max_pages, 0)
        last_page = min(page_middle + max_pages, page_count)
        text = plain_text_output(f.name, page_range=range(first_page, last_page))

    # A zero gap means the text is empty or shorter than max_samples characters.
    sample_gap = len(text) // max_samples
    if sample_gap == 0:
        return no_text

    # Never step by less than a full sample, so samples do not overlap.
    sample_gap = max(sample_gap, sample_len)

    # Split the text into samples for the model. The floor-divided gap can
    # produce one extra trailing sample, so enforce the max_samples cap.
    samples = [text[i:i + sample_len] for i in range(0, len(text), sample_gap)][:max_samples]

    results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
    labels = results.labels

    label = "This PDF has good text."
    # Flag the document when more than 20% of the samples are labelled "bad".
    if labels and labels.count("bad") / len(labels) > .2:
        label = "This PDF may have garbled or bad OCR text."
    return label, labels
85+
4986

5087
def text_detection(img) -> (Image.Image, TextDetectionResult):
5188
pred = batch_text_detection([img], det_model, det_processor)[0]
@@ -139,13 +176,16 @@ def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
139176
)
140177
png = list(renderer)[0]
141178
png_image = png.convert("RGB")
179+
doc.close()
142180
return png_image
143181

144182

145183
@st.cache_data()
def page_counter(pdf_file):
    """Return the number of pages in ``pdf_file``.

    The document opened by ``open_pdf`` is always closed before returning,
    including when counting the pages raises.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        doc.close()
149189

150190

151191
st.set_page_config(layout="wide")
@@ -155,6 +195,7 @@ def page_count(pdf_file):
155195
rec_model, rec_processor = load_rec_cached()
156196
layout_model, layout_processor = load_layout_cached()
157197
table_model, table_processor = load_table_cached()
198+
ocr_error_model, ocr_error_processor = load_ocr_error_cached()
158199

159200

160201
st.markdown("""
@@ -179,8 +220,9 @@ def page_count(pdf_file):
179220

180221
filetype = in_file.type
181222
whole_image = False
223+
page_count = None
182224
if "pdf" in filetype:
183-
page_count = page_count(in_file)
225+
page_count = page_counter(in_file)
184226
page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)
185227

186228
pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
@@ -194,6 +236,7 @@ def page_count(pdf_file):
194236
text_rec = st.sidebar.button("Run OCR")
195237
layout_det = st.sidebar.button("Run Layout Analysis")
196238
table_rec = st.sidebar.button("Run Table Rec")
239+
ocr_errors = st.sidebar.button("Run bad PDF text detection")
197240
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
198241
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")
199242

@@ -233,5 +276,13 @@ def page_count(pdf_file):
233276
st.image(table_img, caption="Table Recognition", use_container_width=True)
234277
st.json([p.model_dump() for p in pred], expanded=True)
235278

279+
if ocr_errors:
    # Bad-text detection reads the PDF's embedded text. For any other upload
    # page_count is still None, so report the error and do not run detection.
    if "pdf" not in filetype:
        st.error("This feature only works with PDFs.")
    else:
        label, results = run_ocr_errors(in_file, page_count)
        with col1:
            st.write(label)
            st.json(results)
286+
236287
with col2:
237288
st.image(pil_image, caption="Uploaded Image", use_container_width=True)

surya/settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
8585
COMPILE_TABLE_REC: bool = False
8686

8787
# OCR Error Detection
88-
OCR_ERROR_MODEL_CHECKPOINT: str = "tarun-menta/ocr_error_detection"
88+
OCR_ERROR_MODEL_CHECKPOINT: str = "datalab-to/ocr_error_detection"
8989
OCR_ERROR_BATCH_SIZE: Optional[int] = None
9090
COMPILE_OCR_ERROR: bool = False
9191

0 commit comments

Comments
 (0)