1
1
import math
2
+ from collections import defaultdict
2
3
from typing import List , Optional
3
4
from PIL import Image
4
5
import numpy as np
@@ -28,7 +29,7 @@ def bbox_avg(integral_image, x1, y1, x2, y2):
28
29
def get_regions_from_detection_result (detection_result : TextDetectionResult , heatmaps : List [Image .Image ], orig_size , id2label , segment_assignment , vertical_line_width = 20 ) -> List [LayoutBox ]:
29
30
logits = np .stack (heatmaps , axis = 0 )
30
31
vertical_line_bboxes = [line for line in detection_result .vertical_lines ]
31
- line_bboxes = [ line for line in detection_result .bboxes ]
32
+ line_bboxes = detection_result .bboxes
32
33
33
34
# Scale back to processor size
34
35
for line in vertical_line_bboxes :
@@ -51,66 +52,57 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
51
52
logits [i , segment_assignment != i ] = 0
52
53
53
54
detected_boxes = []
54
- done_maps = set ()
55
- for iteration in range (100 ): # detect up to 100 boxes
56
- bbox = None
57
- confidence = None
58
- for heatmap_idx in range (1 , len (id2label )): # Skip the blank class
59
- if heatmap_idx in done_maps :
55
+ for heatmap_idx in range (1 , len (id2label )): # Skip the blank class
56
+ heatmap = logits [heatmap_idx ]
57
+ bboxes = get_detected_boxes (heatmap , text_threshold = .9 , low_text = .8 )
58
+ bboxes = [bbox for bbox in bboxes if bbox .area > 25 ]
59
+ for bb in bboxes :
60
+ bb .fit_to_bounds ([0 , 0 , heatmap .shape [1 ] - 1 , heatmap .shape [0 ] - 1 ])
61
+
62
+ integral_image = compute_integral_image (heatmap )
63
+ bbox_confidences = [bbox_avg (integral_image , * [int (b ) for b in bbox .bbox ]) for bbox in bboxes ]
64
+ for confidence , bbox in zip (bbox_confidences , bboxes ):
65
+ if confidence <= .3 :
60
66
continue
61
- heatmap = logits [heatmap_idx ]
62
- bboxes = get_detected_boxes (heatmap , text_threshold = .9 )
63
- bboxes = [bbox for bbox in bboxes if bbox .area > 25 ]
64
- for bb in bboxes :
65
- bb .fit_to_bounds ([0 , 0 , heatmap .shape [1 ] - 1 , heatmap .shape [0 ] - 1 ])
66
-
67
- if len (bboxes ) == 0 :
68
- done_maps .add (heatmap_idx )
69
- continue
70
-
71
- integral_image = compute_integral_image (heatmap )
72
- bbox_confidences = [bbox_avg (integral_image , * [int (b ) for b in bbox .bbox ]) for bbox in bboxes ]
73
-
74
- max_confidence = max (bbox_confidences )
75
- max_confidence_idx = bbox_confidences .index (max_confidence )
76
- if max_confidence >= .15 and (confidence is None or max_confidence > confidence ):
77
- bbox = LayoutBox (polygon = bboxes [max_confidence_idx ].polygon , label = id2label [heatmap_idx ])
78
- elif max_confidence < .15 :
79
- done_maps .add (heatmap_idx )
80
-
81
- if bbox is None :
82
- break
83
-
84
- # Expand bbox to cover intersecting lines
85
- remove_indices = []
86
- covered_lines = []
67
+ bbox = LayoutBox (polygon = bbox .polygon , label = id2label [heatmap_idx ], confidence = confidence )
68
+ detected_boxes .append (bbox )
69
+
70
+ detected_boxes = sorted (detected_boxes , key = lambda x : x .confidence , reverse = True )
71
+ # Expand bbox to cover intersecting lines
72
+ box_lines = defaultdict (list )
73
+ used_lines = set ()
74
+ for bbox_idx , bbox in enumerate (detected_boxes ):
87
75
for line_idx , line_bbox in enumerate (line_bboxes ):
88
- if line_bbox .intersection_pct (bbox ) >= .5 :
89
- remove_indices .append (line_idx )
90
- covered_lines . append ( line_bbox . bbox )
76
+ if line_bbox .intersection_pct (bbox ) >= .5 and line_idx not in used_lines :
77
+ box_lines [ bbox_idx ] .append (line_bbox . bbox )
78
+ used_lines . add ( line_idx )
91
79
92
- logits [:, int (bbox .bbox [1 ]):int (bbox .bbox [3 ]), int (bbox .bbox [0 ]):int (bbox .bbox [2 ])] = 0 # zero out where the detected bbox is
93
- if len (covered_lines ) == 0 and bbox .label not in ["Picture" , "Formula" ]:
80
+ new_boxes = []
81
+ for bbox_idx , bbox in enumerate (detected_boxes ):
82
+ if bbox_idx not in box_lines and bbox .label not in ["Picture" , "Formula" ]:
94
83
continue
95
84
96
- if len ( covered_lines ) > 0 and bbox .label == "Picture" :
85
+ if bbox_idx in box_lines and bbox .label in [ "Picture" ] :
97
86
bbox .label = "Figure"
98
87
88
+ covered_lines = box_lines [bbox_idx ]
99
89
if len (covered_lines ) > 0 and bbox .label not in ["Picture" ]:
100
90
min_x = min ([line [0 ] for line in covered_lines ])
101
91
min_y = min ([line [1 ] for line in covered_lines ])
102
92
max_x = max ([line [2 ] for line in covered_lines ])
103
93
max_y = max ([line [3 ] for line in covered_lines ])
104
94
105
- min_x_box = min ([b [0 ] for b in bbox .polygon ])
106
- min_y_box = min ([b [1 ] for b in bbox .polygon ])
107
- max_x_box = max ([b [0 ] for b in bbox .polygon ])
108
- max_y_box = max ([b [1 ] for b in bbox .polygon ])
95
+ if bbox .label in ["Figure" , "Table" , "Formula" ]:
96
+ # Figures can tables can contain text, but text isn't the whole area
97
+ min_x_box = min ([b [0 ] for b in bbox .polygon ])
98
+ min_y_box = min ([b [1 ] for b in bbox .polygon ])
99
+ max_x_box = max ([b [0 ] for b in bbox .polygon ])
100
+ max_y_box = max ([b [1 ] for b in bbox .polygon ])
109
101
110
- min_x = min (min_x , min_x_box )
111
- min_y = min (min_y , min_y_box )
112
- max_x = max (max_x , max_x_box )
113
- max_y = max (max_y , max_y_box )
102
+ min_x = min (min_x , min_x_box )
103
+ min_y = min (min_y , min_y_box )
104
+ max_x = max (max_x , max_x_box )
105
+ max_y = max (max_y , max_y_box )
114
106
115
107
bbox .polygon [0 ][0 ] = min_x
116
108
bbox .polygon [0 ][1 ] = min_y
@@ -121,21 +113,16 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
121
113
bbox .polygon [3 ][0 ] = min_x
122
114
bbox .polygon [3 ][1 ] = max_y
123
115
124
- # Remove "used" overlap lines
125
- line_bboxes = [line_bboxes [i ] for i in range (len (line_bboxes )) if i not in remove_indices ]
126
- detected_boxes .append (bbox )
127
-
128
- logits [:, int (bbox .bbox [1 ]):int (bbox .bbox [3 ]), int (bbox .bbox [0 ]):int (bbox .bbox [2 ])] = 0 # zero out where the new box is
116
+ new_boxes .append (bbox )
129
117
130
- if len (line_bboxes ) > 0 :
131
- for bbox in line_bboxes :
132
- detected_boxes .append (LayoutBox (polygon = bbox .polygon , label = "Text" ))
118
+ unused_lines = [ line for idx , line in enumerate (line_bboxes ) if idx not in used_lines ]
119
+ for bbox in unused_lines :
120
+ new_boxes .append (LayoutBox (polygon = bbox .polygon , label = "Text" , confidence = .5 ))
133
121
134
- for bbox in detected_boxes :
122
+ for bbox in new_boxes :
135
123
bbox .rescale (list (reversed (heatmap .shape )), orig_size )
136
124
137
- detected_boxes = [bbox for bbox in detected_boxes if bbox .area > 16 ]
138
- detected_boxes = clean_contained_boxes (detected_boxes )
125
+ detected_boxes = [bbox for bbox in new_boxes if bbox .area > 16 ]
139
126
return detected_boxes
140
127
141
128
0 commit comments