
Commit f8188f4

Modify prediction logic
1 parent 4701f96 commit f8188f4

6 files changed: +156 −36 lines

ocr_app.py (+1 −1)

@@ -92,7 +92,7 @@ def table_recognition(img, highres_img, skip_table_detection: bool) -> (Image.Im
                 (item.bbox[3] + table_bbox[1])
             ])
             labels.append(item.label)
-            if hasattr(item, "row_id"):
+            if "Row" in item.label:
                 colors.append("blue")
             else:
                 colors.append("red")

surya/model/table_rec/config.py (+1 −1)

@@ -93,7 +93,7 @@ def __init__(
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         encoder_length=1024,
-        use_positional_embeddings=False,
+        use_positional_embeddings=True,
         **kwargs,
     ):
         super().__init__(**kwargs)

surya/model/table_rec/processor.py (+16 −1)

@@ -54,7 +54,15 @@ def resize_polygon(self, polygon, orig_size, new_size):

         return polygon

-    def __call__(self, images: List[PIL.Image.Image] | None, query_items: List[dict], convert_images: bool = True, *args, **kwargs):
+    def __call__(
+        self,
+        images: List[PIL.Image.Image] | None,
+        query_items: List[dict],
+        columns: List[dict] | None = None,
+        convert_images: bool = True,
+        *args,
+        **kwargs
+    ):
         if convert_images:
             assert len(images) == len(query_items)
             assert len(images) > 0
@@ -75,6 +83,13 @@ def __call__(self, images: List[PIL.Image.Image] | None, query_items: List[dict]
                 [self.token_query_end_id] * col_count
             ])

+        # Add columns to end of decoder input
+        if columns:
+            columns = self.shaper.convert_polygons_to_bboxes(columns)
+            column_labels = self.shaper.dict_to_labels(columns)
+            for decoder_box in decoder_input_boxes:
+                decoder_box += column_labels
+
         input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long)
         input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long)
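With this change the processor can append the first-pass column predictions to every decoder input sequence, so the per-row pass sees the detected columns as extra context. A minimal usage sketch, assuming processor and CATEGORY_TO_ID from surya/tables.py are in scope; the polygon coordinates are invented for illustration:

    # Hypothetical inputs; real polygons come from the first inference pass.
    row_query_items = [{
        "polygon": [[10, 10], [500, 10], [500, 40], [10, 40]],    # one detected row
        "category": CATEGORY_TO_ID["Table-row"],
        "colspan": 0,
        "merges": 0,
    }]
    columns = [{
        "polygon": [[10, 10], [120, 10], [120, 700], [10, 700]],  # one detected column
        "category": CATEGORY_TO_ID["Table-column"],
        "colspan": 0,
        "merges": 0,
    }]

    # Column boxes are converted by the shaper and appended to each decoder input row.
    row_inputs = processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
    row_input_ids = row_inputs["input_ids"]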
surya/schema.py (+23 −0)

@@ -1,4 +1,5 @@
 import copy
+from copy import deepcopy
 from typing import List, Tuple, Any, Optional

 from pydantic import BaseModel, field_validator, computed_field
@@ -71,6 +72,21 @@ def merge(self, other):
         y2 = max(self.bbox[3], other.bbox[3])
         self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

+    def intersection_polygon(self, other) -> List[List[float]]:
+        new_poly = []
+        for i in range(4):
+            if i == 0:
+                new_corner = [max(self.polygon[0][0], other.polygon[0][0]), max(self.polygon[0][1], other.polygon[0][1])]
+            elif i == 1:
+                new_corner = [min(self.polygon[1][0], other.polygon[1][0]), max(self.polygon[1][1], other.polygon[1][1])]
+            elif i == 2:
+                new_corner = [min(self.polygon[2][0], other.polygon[2][0]), min(self.polygon[2][1], other.polygon[2][1])]
+            elif i == 3:
+                new_corner = [max(self.polygon[3][0], other.polygon[3][0]), min(self.polygon[3][1], other.polygon[3][1])]
+            new_poly.append(new_corner)
+
+        return new_poly
+
     def intersection_area(self, other, x_margin=0, y_margin=0):
         x_overlap = self.x_overlap(other, x_margin)
         y_overlap = self.y_overlap(other, y_margin)
@@ -190,10 +206,16 @@ class TableCell(PolygonBox):
     row_id: int
     colspan: int
     within_row_id: int
+    cell_id: int
+    rowspan: int | None = None
     merge_up: bool = False
     merge_down: bool = False
     col_id: int | None = None

+    @property
+    def label(self):
+        return f'{self.row_id} {self.rowspan}/{self.colspan}'
+

 class TableRow(PolygonBox):
     row_id: int
@@ -213,6 +235,7 @@ def label(self):

 class TableResult(BaseModel):
     cells: List[TableCell]
+    unmerged_cells: List[TableCell]
     rows: List[TableRow]
     cols: List[TableCol]
     image_bbox: List[float]
surya/settings.py (+1 −1)

@@ -76,7 +76,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
     ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

     # Table Rec
-    TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/table_rec_2_test"
+    TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/table_rec_2_test2"
     TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
     TABLE_REC_MAX_BOXES: int = 150
     TABLE_REC_BATCH_SIZE: Optional[int] = None

surya/tables.py (+114 −32)

@@ -1,3 +1,5 @@
+from copy import deepcopy
+from itertools import chain
 from typing import List
 import torch
 from PIL import Image
@@ -6,7 +8,7 @@
 from surya.model.table_rec.columns import find_columns
 from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
 from surya.model.table_rec.shaper import LabelShaper
-from surya.schema import TableResult, TableCell, TableRow
+from surya.schema import TableResult, TableCell, TableRow, TableCol, PolygonBox
 from surya.settings import settings
 from tqdm import tqdm
 import numpy as np
@@ -74,7 +76,6 @@ def inference_loop(

     model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

-    print(batch_input_ids)
    with torch.inference_mode():
        token_count = 0
        all_done = torch.zeros(current_batch_size, dtype=torch.bool)
@@ -107,12 +108,13 @@ def inference_loop(
                elif mode == "regression":
                    if k == "bbox":
                        k_logits *= BOX_DIM
+                        k_logits = k_logits.tolist()
                    elif k == "colspan":
                        k_logits = k_logits.clamp(min=1)
-                    box_property[k] = k_logits.tolist()
+                        k_logits = int(k_logits.round().item())
+                    box_property[k] = k_logits
                box_properties.append(box_property)

-            print(box_properties[0])
            all_done = all_done | torch.tensor(done, dtype=torch.bool)

            if all_done.all():
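The regression outputs are now post-processed per key: bbox logits are scaled to pixel space and converted to a plain list, while colspan is clamped to at least 1 and rounded to an integer instead of being left as a float list. A standalone sketch of that branch, with the BOX_DIM value and example tensors assumed for illustration:

    import torch

    BOX_DIM = 1024  # assumed normalization constant; the repo defines its own value

    def postprocess_regression(k: str, k_logits: torch.Tensor):
        # Mirrors the updated branch in inference_loop: bboxes become pixel lists,
        # colspans become integers >= 1.
        if k == "bbox":
            return (k_logits * BOX_DIM).tolist()
        elif k == "colspan":
            return int(k_logits.clamp(min=1).round().item())
        return k_logits

    print(postprocess_regression("bbox", torch.tensor([0.1, 0.2, 0.5, 0.4])))  # [102.4, 204.8, 512.0, 409.6]
    print(postprocess_regression("colspan", torch.tensor(2.4)))                # 2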
@@ -173,69 +175,149 @@ def batch_table_recognition(images: List, model: TableRecEncoderDecoderModel, pr
        with torch.inference_mode():
            encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state

-        row_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size)
+        rowcol_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size)

        row_query_items = []
        row_encoder_hidden_states = []
        idx_map = []
-        for j, img_predictions in enumerate(row_predictions):
+        columns = []
+        for j, img_predictions in enumerate(rowcol_predictions):
            for row_prediction in img_predictions:
                polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
-                row_query_items.append({
-                    "polygon": polygon,
-                    "category": CATEGORY_TO_ID["Table-row"],
-                    "colspan": 0,
-                    "merges": 0,
-                })
-                row_encoder_hidden_states.append(encoder_hidden_states[j])
-                idx_map.append(j)
+                if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]:
+                    row_query_items.append({
+                        "polygon": polygon,
+                        "category": row_prediction["category"],
+                        "colspan": 0,
+                        "merges": 0,
+                    })
+                    row_encoder_hidden_states.append(encoder_hidden_states[j])
+                    idx_map.append(j)
+                elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]:
+                    columns.append({
+                        "polygon": polygon,
+                        "category": row_prediction["category"],
+                        "colspan": 0,
+                        "merges": 0,
+                    })

        row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
-        row_inputs = processor(images=None, query_items=row_query_items, convert_images=False)
+        row_inputs = processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
        row_input_ids = row_inputs["input_ids"].to(model.device)
        cell_predictions = []
-        """
        for j in tqdm(range(0, len(row_input_ids), batch_size), desc="Recognizing tables"):
            cell_batch_hidden_states = row_encoder_hidden_states[j:j+batch_size]
            cell_batch_input_ids = row_input_ids[j:j+batch_size]
            cell_batch_size = len(cell_batch_input_ids)
            cell_predictions.extend(
                inference_loop(model, cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
            )
-        """

-        for j, (img_predictions, orig_size) in enumerate(zip(row_predictions, orig_sizes)):
+        for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
            row_cell_predictions = [c for i,c in enumerate(cell_predictions) if idx_map[i] == j]
            # Each row prediction matches a cell prediction
-            #assert len(img_predictions) == len(row_cell_predictions)
            rows = []
            cells = []
-            for z, row_prediction in enumerate(img_predictions):
+            columns = []
+
+            cell_id = 0
+            row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]]
+            col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]]
+
+            for z, col_prediction in enumerate(col_predictions):
+                polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"])
+                polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
+                columns.append(
+                    TableCol(
+                        polygon=polygon,
+                        col_id=z
+                    )
+                )
+
+            for z, row_prediction in enumerate(row_predictions):
                polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
                polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
-                rows.append(TableRow(
+                row = TableRow(
                    polygon=polygon,
                    row_id=z
-                ))
-                """
-                for l, cell in enumerate(row_cell_predictions[z]):
-                    polygon = shaper.convert_bbox_to_polygon(cell["bbox"])
+                )
+                rows.append(row)
+
+                spanning_cells = []
+                for l, spanning_cell in enumerate(row_cell_predictions[z]):
+                    polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"])
                    polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
-                    cells.append(
+                    spanning_cells.append(
                        TableCell(
                            polygon=polygon,
                            row_id=z,
+                            rowspan=1,
+                            cell_id=cell_id,
                            within_row_id=l,
-                            colspan=max(1, int(cell["colspan"])),
-                            merge_up=cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
-                            merge_down=cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]],
+                            colspan=max(1, int(spanning_cell["colspan"])),
+                            merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]],
+                            merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]],
                        )
                    )
-                """
-            columns = find_columns(rows, cells)
+                    cell_id += 1
+
+                used_spanning_cells = set()
+                for l, col in enumerate(columns):
+                    cell_polygon = row.intersection_polygon(col)
+                    cell_added = False
+                    for zz, spanning_cell in enumerate(spanning_cells):
+                        intersection_pct = PolygonBox(polygon=cell_polygon).intersection_pct(spanning_cell)
+                        if intersection_pct > .5:
+                            cell_added = True
+                            if zz not in used_spanning_cells:
+                                used_spanning_cells.add(zz)
+                                cells.append(spanning_cell)
+
+                    if not cell_added:
+                        cells.append(
+                            TableCell(
+                                polygon=cell_polygon,
+                                row_id=z,
+                                rowspan=1,
+                                cell_id=cell_id,
+                                within_row_id=l,
+                                colspan=1,
+                                merge_up=False,
+                                merge_down=False,
+                            )
+                        )
+                        cell_id += 1
+
+            grid_cells = deepcopy([
+                [cell for cell in cells if cell.row_id == row.row_id]
+                for row in rows
+            ])
+
+            for z, grid_row in enumerate(grid_cells[1:]):
+                prev_row = grid_cells[z]
+                for l, cell in enumerate(grid_row):
+                    if l >= len(prev_row):
+                        continue
+
+                    above_cell = prev_row[l]
+                    if above_cell.merge_down and cell.merge_up:
+                        above_cell.merge(cell)
+                        above_cell.rowspan += cell.rowspan
+                        grid_row[l] = above_cell
+
+            merged_cells_all = list(chain.from_iterable(grid_cells))
+            used_ids = set()
+            merged_cells = []
+            for cell in merged_cells_all:
+                if cell.cell_id in used_ids:
+                    continue
+                used_ids.add(cell.cell_id)
+                merged_cells.append(cell)

            result = TableResult(
-                cells=cells,
+                cells=merged_cells,
+                unmerged_cells=cells,
                rows=rows,
                cols=columns,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],

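The rewritten batch_table_recognition builds the final grid in three steps: intersect every predicted row with every predicted column to get candidate cells, replace candidates that overlap a predicted spanning cell by more than 50% with that spanning cell, then merge vertically adjacent cells whose merge_down/merge_up flags agree and deduplicate by cell_id. A simplified, self-contained sketch of the merge-and-deduplicate step only, with plain dicts standing in for TableCell objects and values invented for illustration:

    from itertools import chain

    # Two rows of one column each; the lower cell continues the upper one.
    grid_cells = [
        [{"cell_id": 0, "rowspan": 1, "merge_down": True,  "merge_up": False}],  # row 0
        [{"cell_id": 1, "rowspan": 1, "merge_down": False, "merge_up": True}],   # row 1
    ]

    # Merge a cell into the one above it when both flags agree, extending the rowspan.
    for z, grid_row in enumerate(grid_cells[1:]):
        prev_row = grid_cells[z]
        for l, cell in enumerate(grid_row):
            if l >= len(prev_row):
                continue
            above_cell = prev_row[l]
            if above_cell["merge_down"] and cell["merge_up"]:
                above_cell["rowspan"] += cell["rowspan"]
                grid_row[l] = above_cell  # both slots now reference the merged cell

    # Deduplicate by cell_id so the merged cell appears once in the output.
    merged, seen = [], set()
    for cell in chain.from_iterable(grid_cells):
        if cell["cell_id"] not in seen:
            seen.add(cell["cell_id"])
            merged.append(cell)

    print(merged)  # a single cell with rowspan 2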