Skip to content

Commit 2a4716e

Browse files
committed
Merge remote-tracking branch 'origin/layout2' into layout2
2 parents 3fe9d16 + 6cc9b23 commit 2a4716e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

surya/ocr.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ def run_recognition(images: List[Image.Image], langs: List[List[str] | None], re
6060
return predictions_by_image
6161

6262

63-
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]:
63+
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, detection_batch_size=None, recognition_batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]:
6464
images = convert_if_not_rgb(images)
6565
highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images)
66-
det_predictions = batch_text_detection(images, det_model, det_processor)
66+
det_predictions = batch_text_detection(images, det_model, det_processor, batch_size=detection_batch_size)
6767

6868
all_slices = []
6969
slice_map = []
@@ -82,7 +82,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model,
8282
all_langs.extend([lang] * len(slices))
8383
all_slices.extend(slices)
8484

85-
rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size)
85+
rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=recognition_batch_size)
8686

8787
predictions_by_image = []
8888
slice_start = 0

0 commit comments

Comments
 (0)