Skip to content

Commit 2751dcf

Browse files
committed
Add faster precision calculation
1 parent daf4e6d commit 2751dcf

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

benchmark/layout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def main():
106106

107107
print(tabulate(table_data, headers=table_headers, tablefmt="github"))
108108
print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.")
109-
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. There is a precision penalty for multiple boxes overlapping reference lines.")
109+
print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.")
110110
print(f"Wrote results to {result_path}")
111111

112112

surya/benchmark/metrics.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ def calculate_coverage(box, other_boxes, penalize_double=False):
5555
return covered_pixels_count / box_area
5656

5757

58+
def calculate_coverage_fast(box, other_boxes, penalize_double=False):
59+
box_area = (box[2] - box[0]) * (box[3] - box[1])
60+
if box_area == 0:
61+
return 0
62+
63+
total_intersect = 0
64+
for other_box in other_boxes:
65+
total_intersect += intersection_area(box, other_box)
66+
67+
return min(1, total_intersect / box_area)
68+
69+
5870
def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
5971
if len(references) == 0:
6072
return {
@@ -68,10 +80,15 @@ def precision_recall(preds, references, threshold=.5, workers=8, penalize_double
6880
"recall": 0,
6981
}
7082

83+
# If we're not penalizing double coverage, we can use a faster calculation
84+
coverage_func = calculate_coverage_fast
85+
if penalize_double:
86+
coverage_func = calculate_coverage
87+
7188
with ProcessPoolExecutor(max_workers=workers) as executor:
72-
precision_func = partial(calculate_coverage, penalize_double=penalize_double)
89+
precision_func = partial(coverage_func, penalize_double=penalize_double)
7390
precision_iou = executor.map(precision_func, preds, repeat(references))
74-
reference_iou = executor.map(calculate_coverage, references, repeat(preds))
91+
reference_iou = executor.map(coverage_func, references, repeat(preds))
7592

7693
precision_classes = [1 if i > threshold else 0 for i in precision_iou]
7794
precision = sum(precision_classes) / len(precision_classes)

surya/postprocessing/fonts.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import List
1+
from typing import List, Optional
22
import os
33
import requests
44

55
from surya.settings import settings
66

77

8-
def get_font_path(langs: List[str] | None = None) -> str:
8+
def get_font_path(langs: Optional[List[str]] = None) -> str:
99
font_path = settings.RECOGNITION_RENDER_FONTS["all"]
1010
if langs is not None:
1111
for k in settings.RECOGNITION_RENDER_FONTS:

surya/settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
5050
DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400 # Height at which to slice images vertically
5151
DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text)
5252
DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank)
53-
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = 8 # Number of workers for postprocessing
53+
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing
5454

5555
# Text recognition
5656
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"

0 commit comments

Comments
 (0)