Commit 91bf7fb

table recognition on XLA too!
1 parent cd23434 commit 91bf7fb

File tree (3 files changed, +36 -7 lines):

  surya/model/table_rec/decoder.py  (+15 -2)
  surya/recognition.py              (+1)
  surya/tables.py                   (+20 -5)
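All three files pick up the same pattern: a guarded torch_xla import plus a small mark_step() helper that only calls xm.mark_step() when settings.TORCH_DEVICE_MODEL is 'xla'. The reason for threading mark_step() through the decode loops is how XLA runs PyTorch ops: XLA tensors are lazy, so operations are recorded into a graph and only compiled and executed when the step is marked. A minimal sketch of that behavior, separate from surya (it assumes torch_xla is installed and an XLA device such as a TPU core is available):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()           # e.g. a TPU core
    x = torch.zeros(4, device=device)  # an XLA tensor: operations on it are lazy

    for step in range(3):
        x = x + 1                      # recorded into the pending graph, not run yet
        xm.mark_step()                 # cut the graph here: compile and execute this iteration

    print(x)                           # reading the value forces any remaining work to run

Calling mark_step() once per generated token keeps each compiled graph small and roughly the same shape, instead of letting the whole autoregressive loop accumulate into one huge trace; on CPU, CUDA, or MPS the helper does nothing, so those paths are unchanged.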

surya/model/table_rec/decoder.py (+15 -2)

@@ -9,6 +9,16 @@
 from surya.model.table_rec.config import TableRecModelOutput, SuryaTableRecTextEncoderConfig
 from surya.settings import settings
 
+try:
+    import torch_xla.core.xla_model as xm
+except:
+    pass
+
+
+def mark_step():
+    if settings.TORCH_DEVICE_MODEL == 'xla':
+        xm.mark_step()
+
 
 class LabelEmbedding(nn.Module):
     def __init__(self, config):
@@ -97,8 +107,9 @@ def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
         if self.embed_positions:
             for j in range(embedded.shape[0]):
                 box_start = input_box_counts[j, 0]
-                box_end = input_box_counts[j, 1] - 1 # Skip the sep token
+                box_end = input_box_counts[j, 1] - 1  # Skip the sep token
                 box_count = box_end - box_start
+                mark_step()
                 embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]
 
         return embedded
@@ -178,6 +189,8 @@ def forward(
             class_logits=class_logits,
             hidden_states=hidden_states,
         )
+
+
 @dataclass
 class TextEncoderOutput(CausalLMOutput):
     hidden_states: torch.FloatTensor = None
@@ -239,4 +252,4 @@ def forward(
 
         return TextEncoderOutput(
             hidden_states=outputs.last_hidden_state,
-        )
+        )

surya/recognition.py (+1)

@@ -163,6 +163,7 @@ def batch_recognition(images: List[Image.Image], languages: List[List[str] | Non
             batch_decoder_input = preds.unsqueeze(1)
 
             for j, (pred, status) in enumerate(zip(preds, all_done)):
+                mark_step()
                 if not status:
                     mark_step()
                     batch_predictions[j].append(int(pred))

surya/tables.py (+20 -5)

@@ -10,6 +10,16 @@
 import numpy as np
 from surya.model.table_rec.config import SPECIAL_TOKENS
 
+try:
+    import torch_xla.core.xla_model as xm
+except:
+    pass
+
+
+def mark_step():
+    if settings.TORCH_DEVICE_MODEL == 'xla':
+        xm.mark_step()
+
 
 def get_batch_size():
     batch_size = settings.TABLE_REC_BATCH_SIZE
@@ -60,11 +70,11 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
 
     output_order = []
     for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
-        batch_table_cells = deepcopy(table_cells[i:i+batch_size])
-        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells] # Sort bboxes before passing in
+        batch_table_cells = deepcopy(table_cells[i:i + batch_size])
+        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells]  # Sort bboxes before passing in
         batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]
 
-        batch_images = images[i:i+batch_size]
+        batch_images = images[i:i + batch_size]
         batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images
 
         current_batch_size = len(batch_images)
@@ -84,6 +94,7 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
 
         # Setup inputs for the decoder
        batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
+        mark_step()
         batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
         inference_token_count = batch_decoder_input.shape[1]
 
@@ -140,21 +151,25 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
             done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
             all_done = all_done | done
 
+            mark_step()
             if all_done.all():
                 break
 
             batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)
 
             for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
+                mark_step()
                 if not status:
                     batch_predictions[j].append(pred[0].tolist())
 
             token_count += inference_token_count
             inference_token_count = batch_decoder_input.shape[1]
-
+            mark_step()
+
             if settings.TABLE_REC_STATIC_CACHE:
                 batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size)
 
+        mark_step()
         if settings.TABLE_REC_STATIC_CACHE:
             batch_predictions = batch_predictions[:current_batch_size]
             batch_table_cells = batch_table_cells[:current_batch_size]
@@ -215,4 +230,4 @@ def batch_table_recognition(images: List, table_cells: List[List[Dict]], model:
 
         del text_encoder_hidden_states
 
-    return output_order
+    return output_order
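For reference, the guarded import and helper that this commit adds to decoder.py and tables.py can also be written as a self-contained snippet that loads cleanly where torch_xla is absent. The settings import mirrors the usage in the diff above; catching ImportError and the xm-is-None check are an illustrative hardening, not the committed code:

    from surya.settings import settings

    try:
        import torch_xla.core.xla_model as xm
    except ImportError:  # torch_xla is optional and absent on CPU/GPU-only installs
        xm = None


    def mark_step():
        # Flush the lazy XLA graph once per decoding step; a no-op everywhere else.
        if xm is not None and settings.TORCH_DEVICE_MODEL == 'xla':
            xm.mark_step()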
