 import numpy as np
 from surya.model.table_rec.config import SPECIAL_TOKENS

+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    pass
+
+
+def mark_step():
+    if settings.TORCH_DEVICE_MODEL == 'xla':
+        xm.mark_step()
+

 def get_batch_size():
     batch_size = settings.TABLE_REC_BATCH_SIZE
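Note on the hunk above: torch_xla becomes an optional dependency, and every XLA sync point is routed through one helper. The settings check is what keeps the helper safe when the import failed, since xm is never touched off-XLA. A minimal standalone sketch of the same guard pattern, with one assumption swapped in: binding xm to None on ImportError instead of consulting surya's settings module.

    # Hedged sketch of the optional-dependency guard; xm = None stands in
    # for the settings.TORCH_DEVICE_MODEL check used in the actual patch.
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        xm = None  # installs without torch_xla land here

    def mark_step():
        # On XLA, xm.mark_step() cuts the lazily recorded graph and runs it;
        # without torch_xla this is a no-op.
        if xm is not None:
            xm.mark_step()

    mark_step()  # safe to call unconditionally, on any device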
@@ -60,11 +70,11 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:

     output_order = []
     for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
-        batch_table_cells = deepcopy(table_cells[i:i+batch_size])
-        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells]  # Sort bboxes before passing in
+        batch_table_cells = deepcopy(table_cells[i:i+batch_size])
+        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells]  # Sort bboxes before passing in
         batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]

-        batch_images = images[i:i+batch_size]
+        batch_images = images[i:i+batch_size]
         batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images

         current_batch_size = len(batch_images)
@@ -84,6 +94,7 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:

         # Setup inputs for the decoder
         batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
+        mark_step()
         batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
         inference_token_count = batch_decoder_input.shape[1]

@@ -140,21 +151,25 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
                 done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
                 all_done = all_done | done

+                mark_step()
                 if all_done.all():
                     break

                 batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)

                 for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
+                    mark_step()
                     if not status:
                         batch_predictions[j].append(pred[0].tolist())

                 token_count += inference_token_count
                 inference_token_count = batch_decoder_input.shape[1]
-
+                mark_step()
+
                 if settings.TABLE_REC_STATIC_CACHE:
                     batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)

+        mark_step()
         if settings.TABLE_REC_STATIC_CACHE:
             batch_predictions = batch_predictions[:current_batch_size]
             batch_table_cells = batch_table_cells[:current_batch_size]
@@ -215,4 +230,4 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:

         del text_encoder_hidden_states

-    return output_order
+    return output_order
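
Why mark_step() appears at each of these points: XLA tensors execute lazily, so without explicit cuts the autoregressive decoding loop would keep extending one ever-growing graph. Calling mark_step() at step boundaries presumably bounds the graph and forces execution, and the TABLE_REC_STATIC_CACHE padding to a fixed batch size keeps step shapes constant, which lets XLA reuse compiled graphs instead of recompiling. A toy fixed-shape loop illustrating the same pattern; the device selection and loop body are illustrative only, not surya's decoder.

    import torch
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        xm = None

    def mark_step():
        if xm is not None:
            xm.mark_step()

    device = xm.xla_device() if xm is not None else torch.device("cpu")
    tokens = torch.zeros(1, 8, dtype=torch.long, device=device)
    for step in range(1, 8):  # stand-in for the while loop over decode steps
        tokens[:, step] = (tokens[:, step - 1] + 1) % 5  # stand-in forward pass
        mark_step()  # flush this step's graph; fixed shapes let XLA reuse it
    print(tokens.cpu().tolist())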