@@ -19,24 +19,26 @@ def timeit_wrapper(*args, **kwargs):
19
19
result = func (* args , ** kwargs )
20
20
end_time = time .perf_counter ()
21
21
total_time = end_time - start_time
22
- print (f' Function { func .__name__ } { args } { kwargs } Took { total_time :.4f} seconds' )
22
+ print (f" Function { func .__name__ } { args } { kwargs } Took { total_time :.4f} seconds" )
23
23
return result
24
+
24
25
return timeit_wrapper
25
26
27
+
26
28
# Load environment variables (via load_dotenv) before the service ports
# below are looked up in ``env``.
load_dotenv()

# One port per inference service, each read from ``env`` by its own key.
TABLE_DETECTION_PORT = env["TABLE_DETECTION_PORT"]
TABLE_RECOGNITION_PORT = env["TABLE_RECOGNITION_PORT"]
TEXT_DETECTION_PORT = env["TEXT_DETECTION_PORT"]
TEXT_RECOGNITION_PORT = env["TEXT_RECOGNITION_PORT"]
32
34
33
35
34
36
class Box (BaseModel ):
35
- name : str = "box"
36
- xmin : int
37
- xmax : int
38
- ymin : int
39
- ymax : int
37
+ name : str = "box"
38
+ xmin : int
39
+ xmax : int
40
+ ymin : int
41
+ ymax : int
40
42
41
43
@property
42
44
def width (self ):
@@ -61,38 +63,42 @@ def get_intersection(self, box):
61
63
return (xmax - xmin ) * (ymax - ymin )
62
64
return 0
63
65
66
+
64
67
class Text(Box):
    """A ``Box`` that additionally carries an OCR string."""

    # Overrides the default ``name`` inherited from ``Box``.
    name: str = "text"
    # OCR string attached to this box; empty unless one is supplied.
    ocr: str = ""
70
+
67
71
68
72
class Cell(Box):
    """A ``Box`` that also holds the ``Text`` boxes assigned to it."""

    # Overrides the default ``name`` inherited from ``Box``.
    name: str = "cell"
    # Text boxes attached to this cell; empty by default.
    texts: List[Text] = []

    def is_valid(self):
        """Return True when the cell exceeds both minimum dimensions.

        The width must be strictly greater than ``CELL_MIN_WIDTH`` and the
        height strictly greater than ``CELL_MIN_HEIGHT``.
        """
        # Check the width first; the height is only compared when it passes.
        if not self.width > CELL_MIN_WIDTH:
            return False
        return self.height > CELL_MIN_HEIGHT
74
78
75
79
76
80
class Table(Box):
    """A ``Box`` that also holds the ``Cell`` boxes belonging to it."""

    # Overrides the default ``name`` inherited from ``Box``.
    name: str = "table"
    # Cells attached to this table; empty by default.
    cells: List[Cell] = []
83
+
79
84
80
def read_tables_from_list(input_list: List[Dict]) -> List[Table]:
    """Build ``Table`` objects from a list of dictionaries.

    Entries whose ``"name"`` is not ``"table"`` are ignored. For every
    remaining entry a ``Table`` is created from its ``xmin``/``ymin``/
    ``xmax``/``ymax`` values, and its ``cells`` are filled with one ``Cell``
    per element of the entry's ``"cells"`` list, using the same four keys.

    Args:
        input_list: Dictionaries that each contain at least a ``"name"`` key.

    Returns:
        The tables in the order their entries appear in ``input_list``.
    """
    corner_keys = ("xmin", "ymin", "xmax", "ymax")
    parsed = []
    for entry in input_list:
        if entry["name"] == "table":
            table = Table(**{key: entry[key] for key in corner_keys})
            # Cells are attached after construction rather than passed to
            # the constructor.
            table.cells = [
                Cell(**{key: raw_cell[key] for key in corner_keys})
                for raw_cell in entry["cells"]
            ]
            parsed.append(table)
    return parsed
94
99
95
- def read_texts_from_list (input_list : List [Dict ]) -> List [Text ]:
100
+
101
+ def read_texts_from_list (input_list : List [Dict ]) -> List [Text ]:
96
102
texts = []
97
103
for item in input_list :
98
104
if item ["name" ] != "text" :
@@ -103,34 +109,44 @@ def read_texts_from_list(input_list : List[Dict]) -> List[Text]:
103
109
ymin = item ["ymin" ],
104
110
xmax = item ["xmax" ],
105
111
ymax = item ["ymax" ],
106
- ocr = item ["ocr" ]
112
+ ocr = item ["ocr" ],
107
113
)
108
114
)
109
115
return texts
110
116
117
+
111
118
@timeit
def get_table(image_path):
    """Upload an image to the table-recognition service and return its reply.

    The file is sent as the multipart field ``file`` in a POST request to
    ``http://localhost:<TABLE_RECOGNITION_PORT>/ai/infer``.

    Args:
        image_path: Path of the image file to upload.

    Returns:
        The JSON-decoded response body (``response.json()``).
    """
    image_name = os.path.basename(image_path)
    url = f"http://localhost:{TABLE_RECOGNITION_PORT}/ai/infer"
    # The guessed MIME type may be None when the extension is unknown.
    content_type = mimetypes.guess_type(image_path)[0]
    # Open the image inside ``with`` so the handle is closed as soon as the
    # upload is done; previously the file object was never closed.
    with open(image_path, "rb") as image_file:
        files = [("file", (image_name, image_file, content_type))]
        response = requests.request("POST", url, files=files)
    return response.json()
120
130
131
+
121
132
@timeit
def get_ocr(image_path):
    """Upload an image to the text-recognition service and return its reply.

    The file is sent as the multipart field ``file`` in a POST request to
    ``http://localhost:<TEXT_RECOGNITION_PORT>/ai/infer``.

    Args:
        image_path: Path of the image file to upload.

    Returns:
        The JSON-decoded response body (``response.json()``).
    """
    image_name = os.path.basename(image_path)
    url = f"http://localhost:{TEXT_RECOGNITION_PORT}/ai/infer"
    # The guessed MIME type may be None when the extension is unknown.
    content_type = mimetypes.guess_type(image_path)[0]
    # Open the image inside ``with`` so the handle is closed as soon as the
    # upload is done; previously the file object was never closed.
    with open(image_path, "rb") as image_file:
        files = [("file", (image_name, image_file, content_type))]
        response = requests.request("POST", url, files=files)
    return response.json()
130
144
145
+
131
146
def get_random_color():
    """Return a random colour as a tuple of three Python ints.

    Each channel is a uniform sample from ``[0, 1)`` scaled by 153 and
    shifted by 102, then truncated to ``uint8``, so every value lies in
    the range 102..254.
    """
    channels = np.random.random(3) * 153 + 102
    as_bytes = channels.astype(np.uint8)
    return tuple(as_bytes.tolist())
133
148
149
+
134
150
def show (img , name = "disp" , width = 1000 ):
135
151
"""
136
152
name: name of window, should be name of img
@@ -143,25 +159,43 @@ def show(img, name="disp", width=1000):
143
159
cv2 .destroyAllWindows ()
144
160
145
161
146
- def draw (image , table_list : List [Table ]):
162
+ def draw (image , table_list : List [Table ]):
147
163
vis_image = image .copy ()
148
164
149
165
# draw cell
150
166
for table in table_list :
151
167
for cell in table .cells :
152
- cv2 .rectangle (vis_image , (cell .xmin , cell .ymin ), (cell .xmax , cell .ymax ), get_random_color (), - 1 )
168
+ cv2 .rectangle (
169
+ vis_image ,
170
+ (cell .xmin , cell .ymin ),
171
+ (cell .xmax , cell .ymax ),
172
+ get_random_color (),
173
+ - 1 ,
174
+ )
153
175
154
176
vis_image = vis_image // 2 + image // 2
155
177
156
178
# draw table
157
179
for table in table_list :
158
- cv2 .rectangle (vis_image , (table .xmin , table .ymin ), (table .xmax , table .ymax ), (0 , 0 , 255 ), 4 )
180
+ cv2 .rectangle (
181
+ vis_image ,
182
+ (table .xmin , table .ymin ),
183
+ (table .xmax , table .ymax ),
184
+ (0 , 0 , 255 ),
185
+ 4 ,
186
+ )
159
187
160
188
# draw text
161
189
for table in table_list :
162
190
for cell in table .cells :
163
191
for text in cell .texts :
164
- cv2 .rectangle (vis_image , (text .xmin , text .ymin ), (text .xmax , text .ymax ), (255 , 0 , 0 ), 2 )
192
+ cv2 .rectangle (
193
+ vis_image ,
194
+ (text .xmin , text .ymin ),
195
+ (text .xmax , text .ymax ),
196
+ (255 , 0 , 0 ),
197
+ 2 ,
198
+ )
165
199
166
200
# put text
167
201
for table in table_list :
@@ -172,43 +206,48 @@ def draw(image, table_list : List[Table]):
172
206
text .ocr ,
173
207
(text .xmin , text .ymin ),
174
208
cv2 .FONT_HERSHEY_SIMPLEX ,
175
- 0.5 , (0 , 255 , 0 ), 1 )
209
+ 0.5 ,
210
+ (0 , 255 , 0 ),
211
+ 1 ,
212
+ )
176
213
return vis_image
177
214
178
215
179
def draw_text(image, text_list: List[Text]):
    """Outline every text box on ``image`` and return the image.

    No copy is made: rectangles are drawn onto the ``image`` object that
    was passed in, and that same object is returned.

    Args:
        image: Image to draw on (handed straight to ``cv2.rectangle``).
        text_list: Boxes whose ``xmin``/``ymin``/``xmax``/``ymax`` give the
            rectangle corners.

    Returns:
        The ``image`` argument.
    """
    outline_color = (255, 0, 0)
    outline_thickness = 2
    for box in text_list:
        top_left = (box.xmin, box.ymin)
        bottom_right = (box.xmax, box.ymax)
        cv2.rectangle(image, top_left, bottom_right, outline_color, outline_thickness)
    return image
183
222
184
223
185
def merge_text_table(tables: List[Table], texts: List[Text]):
    """Attach texts to the table cells they overlap, in place.

    For each table, only texts with a positive intersection with the table
    are considered. Each cell's ``texts`` is then replaced by the candidates
    for which the intersection with the cell, divided by the text's own
    area, is greater than 0.4. Nothing is returned.

    Args:
        tables: Tables whose cells receive a new ``texts`` list.
        texts: Candidate text boxes.
    """
    min_overlap_ratio = 0.4
    for table in tables:
        # Narrow down to texts touching this table before looking at cells.
        candidates = []
        for text in texts:
            if text.get_intersection(table) > 0:
                candidates.append(text)

        for cell in table.cells:
            matched = []
            for text in candidates:
                overlap_ratio = text.get_intersection(cell) / text.area
                if overlap_ratio > min_overlap_ratio:
                    matched.append(text)
            cell.texts = matched
232
+
191
233
192
234
@timeit
def main(image_path="/home/luan/research/Go5-Project/sample.jpg"):
    """Recognise tables and texts in one image and display the overlay.

    The image is sent to ``get_table`` and ``get_ocr``; their outputs are
    parsed with ``read_tables_from_list`` and ``read_texts_from_list``, the
    texts are merged into the table cells, and the result of ``draw`` is
    passed to ``show``.

    Args:
        image_path: Path of the image to process. Defaults to the sample
            path that was previously hard-coded, so ``main()`` still works
            without arguments.

    Raises:
        ValueError: If ``cv2.imread`` cannot load ``image_path``.
    """
    # read table
    table_output: List = get_table(image_path)
    tables: List[Table] = read_tables_from_list(table_output)

    # read text
    ocr_output: List = get_ocr(image_path)
    texts: List[Text] = read_texts_from_list(ocr_output)

    merge_text_table(tables, texts)

    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than later inside draw().
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    show(draw(image, tables))


if __name__ == "__main__":
    main()
0 commit comments