|
| 1 | +from copy import deepcopy |
| 2 | +from itertools import chain |
1 | 3 | from typing import List
|
2 | 4 | import torch
|
3 | 5 | from PIL import Image
|
|
6 | 8 | from surya.model.table_rec.columns import find_columns
|
7 | 9 | from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
|
8 | 10 | from surya.model.table_rec.shaper import LabelShaper
|
9 |
| -from surya.schema import TableResult, TableCell, TableRow |
| 11 | +from surya.schema import TableResult, TableCell, TableRow, TableCol, PolygonBox |
10 | 12 | from surya.settings import settings
|
11 | 13 | from tqdm import tqdm
|
12 | 14 | import numpy as np
|
@@ -74,7 +76,6 @@ def inference_loop(
|
74 | 76 |
|
75 | 77 | model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
|
76 | 78 |
|
77 |
| - print(batch_input_ids) |
78 | 79 | with torch.inference_mode():
|
79 | 80 | token_count = 0
|
80 | 81 | all_done = torch.zeros(current_batch_size, dtype=torch.bool)
|
@@ -107,12 +108,13 @@ def inference_loop(
|
107 | 108 | elif mode == "regression":
|
108 | 109 | if k == "bbox":
|
109 | 110 | k_logits *= BOX_DIM
|
| 111 | + k_logits = k_logits.tolist() |
110 | 112 | elif k == "colspan":
|
111 | 113 | k_logits = k_logits.clamp(min=1)
|
112 |
| - box_property[k] = k_logits.tolist() |
| 114 | + k_logits = int(k_logits.round().item()) |
| 115 | + box_property[k] = k_logits |
113 | 116 | box_properties.append(box_property)
|
114 | 117 |
|
115 |
| - print(box_properties[0]) |
116 | 118 | all_done = all_done | torch.tensor(done, dtype=torch.bool)
|
117 | 119 |
|
118 | 120 | if all_done.all():
|
@@ -173,69 +175,149 @@ def batch_table_recognition(images: List, model: TableRecEncoderDecoderModel, pr
|
173 | 175 | with torch.inference_mode():
|
174 | 176 | encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
|
175 | 177 |
|
176 |
| - row_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size) |
| 178 | + rowcol_predictions = inference_loop(model, encoder_hidden_states, batch_input_ids, current_batch_size, batch_size) |
177 | 179 |
|
178 | 180 | row_query_items = []
|
179 | 181 | row_encoder_hidden_states = []
|
180 | 182 | idx_map = []
|
181 |
| - for j, img_predictions in enumerate(row_predictions): |
| 183 | + columns = [] |
| 184 | + for j, img_predictions in enumerate(rowcol_predictions): |
182 | 185 | for row_prediction in img_predictions:
|
183 | 186 | polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
|
184 |
| - row_query_items.append({ |
185 |
| - "polygon": polygon, |
186 |
| - "category": CATEGORY_TO_ID["Table-row"], |
187 |
| - "colspan": 0, |
188 |
| - "merges": 0, |
189 |
| - }) |
190 |
| - row_encoder_hidden_states.append(encoder_hidden_states[j]) |
191 |
| - idx_map.append(j) |
| 187 | + if row_prediction["category"] == CATEGORY_TO_ID["Table-row"]: |
| 188 | + row_query_items.append({ |
| 189 | + "polygon": polygon, |
| 190 | + "category": row_prediction["category"], |
| 191 | + "colspan": 0, |
| 192 | + "merges": 0, |
| 193 | + }) |
| 194 | + row_encoder_hidden_states.append(encoder_hidden_states[j]) |
| 195 | + idx_map.append(j) |
| 196 | + elif row_prediction["category"] == CATEGORY_TO_ID["Table-column"]: |
| 197 | + columns.append({ |
| 198 | + "polygon": polygon, |
| 199 | + "category": row_prediction["category"], |
| 200 | + "colspan": 0, |
| 201 | + "merges": 0, |
| 202 | + }) |
192 | 203 |
|
193 | 204 | row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
|
194 |
| - row_inputs = processor(images=None, query_items=row_query_items, convert_images=False) |
| 205 | + row_inputs = processor(images=None, query_items=row_query_items, columns=columns, convert_images=False) |
195 | 206 | row_input_ids = row_inputs["input_ids"].to(model.device)
|
196 | 207 | cell_predictions = []
|
197 |
| - """ |
198 | 208 | for j in tqdm(range(0, len(row_input_ids), batch_size), desc="Recognizing tables"):
|
199 | 209 | cell_batch_hidden_states = row_encoder_hidden_states[j:j+batch_size]
|
200 | 210 | cell_batch_input_ids = row_input_ids[j:j+batch_size]
|
201 | 211 | cell_batch_size = len(cell_batch_input_ids)
|
202 | 212 | cell_predictions.extend(
|
203 | 213 | inference_loop(model, cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
|
204 | 214 | )
|
205 |
| - """ |
206 | 215 |
|
207 |
| - for j, (img_predictions, orig_size) in enumerate(zip(row_predictions, orig_sizes)): |
| 216 | + for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)): |
208 | 217 | row_cell_predictions = [c for i,c in enumerate(cell_predictions) if idx_map[i] == j]
|
209 | 218 | # Each row prediction matches a cell prediction
|
210 |
| - #assert len(img_predictions) == len(row_cell_predictions) |
211 | 219 | rows = []
|
212 | 220 | cells = []
|
213 |
| - for z, row_prediction in enumerate(img_predictions): |
| 221 | + columns = [] |
| 222 | + |
| 223 | + cell_id = 0 |
| 224 | + row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]] |
| 225 | + col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]] |
| 226 | + |
| 227 | + for z, col_prediction in enumerate(col_predictions): |
| 228 | + polygon = shaper.convert_bbox_to_polygon(col_prediction["bbox"]) |
| 229 | + polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size) |
| 230 | + columns.append( |
| 231 | + TableCol( |
| 232 | + polygon=polygon, |
| 233 | + col_id=z |
| 234 | + ) |
| 235 | + ) |
| 236 | + |
| 237 | + for z, row_prediction in enumerate(row_predictions): |
214 | 238 | polygon = shaper.convert_bbox_to_polygon(row_prediction["bbox"])
|
215 | 239 | polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
|
216 |
| - rows.append(TableRow( |
| 240 | + row = TableRow( |
217 | 241 | polygon=polygon,
|
218 | 242 | row_id=z
|
219 |
| - )) |
220 |
| - """ |
221 |
| - for l, cell in enumerate(row_cell_predictions[z]): |
222 |
| - polygon = shaper.convert_bbox_to_polygon(cell["bbox"]) |
| 243 | + ) |
| 244 | + rows.append(row) |
| 245 | + |
| 246 | + spanning_cells = [] |
| 247 | + for l, spanning_cell in enumerate(row_cell_predictions[z]): |
| 248 | + polygon = shaper.convert_bbox_to_polygon(spanning_cell["bbox"]) |
223 | 249 | polygon = processor.resize_polygon(polygon, (BOX_DIM, BOX_DIM), orig_size)
|
224 |
| - cells.append( |
| 250 | + spanning_cells.append( |
225 | 251 | TableCell(
|
226 | 252 | polygon=polygon,
|
227 | 253 | row_id=z,
|
| 254 | + rowspan=1, |
| 255 | + cell_id=cell_id, |
228 | 256 | within_row_id=l,
|
229 |
| - colspan=max(1, int(cell["colspan"])), |
230 |
| - merge_up=cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], |
231 |
| - merge_down=cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]], |
| 257 | + colspan=max(1, int(spanning_cell["colspan"])), |
| 258 | + merge_up=spanning_cell["merges"] in [MERGE_KEYS["merge_up"], MERGE_KEYS["merge_both"]], |
| 259 | + merge_down=spanning_cell["merges"] in [MERGE_KEYS["merge_down"], MERGE_KEYS["merge_both"]], |
232 | 260 | )
|
233 | 261 | )
|
234 |
| - """ |
235 |
| - columns = find_columns(rows, cells) |
| 262 | + cell_id += 1 |
| 263 | + |
| 264 | + |
| 265 | + used_spanning_cells = set() |
| 266 | + for l, col in enumerate(columns): |
| 267 | + cell_polygon = row.intersection_polygon(col) |
| 268 | + cell_added = False |
| 269 | + for zz, spanning_cell in enumerate(spanning_cells): |
| 270 | + intersection_pct = PolygonBox(polygon=cell_polygon).intersection_pct(spanning_cell) |
| 271 | + if intersection_pct > .5: |
| 272 | + cell_added = True |
| 273 | + if zz not in used_spanning_cells: |
| 274 | + used_spanning_cells.add(zz) |
| 275 | + cells.append(spanning_cell) |
| 276 | + |
| 277 | + if not cell_added: |
| 278 | + cells.append( |
| 279 | + TableCell( |
| 280 | + polygon=cell_polygon, |
| 281 | + row_id=z, |
| 282 | + rowspan=1, |
| 283 | + cell_id=cell_id, |
| 284 | + within_row_id=l, |
| 285 | + colspan=1, |
| 286 | + merge_up=False, |
| 287 | + merge_down=False, |
| 288 | + ) |
| 289 | + ) |
| 290 | + cell_id += 1 |
| 291 | + |
| 292 | + grid_cells = deepcopy([ |
| 293 | + [cell for cell in cells if cell.row_id == row.row_id] |
| 294 | + for row in rows |
| 295 | + ]) |
| 296 | + |
| 297 | + for z, grid_row in enumerate(grid_cells[1:]): |
| 298 | + prev_row = grid_cells[z] |
| 299 | + for l, cell in enumerate(grid_row): |
| 300 | + if l >= len(prev_row): |
| 301 | + continue |
| 302 | + |
| 303 | + above_cell = prev_row[l] |
| 304 | + if above_cell.merge_down and cell.merge_up: |
| 305 | + above_cell.merge(cell) |
| 306 | + above_cell.rowspan += cell.rowspan |
| 307 | + grid_row[l] = above_cell |
| 308 | + merged_cells_all = list(chain.from_iterable(grid_cells)) |
| 309 | + used_ids = set() |
| 310 | + merged_cells = [] |
| 311 | + for cell in merged_cells_all: |
| 312 | + if cell.cell_id in used_ids: |
| 313 | + continue |
| 314 | + used_ids.add(cell.cell_id) |
| 315 | + merged_cells.append(cell) |
| 316 | + |
236 | 317 |
|
237 | 318 | result = TableResult(
|
238 |
| - cells=cells, |
| 319 | + cells=merged_cells, |
| 320 | + unmerged_cells=cells, |
239 | 321 | rows=rows,
|
240 | 322 | cols=columns,
|
241 | 323 | image_bbox=[0, 0, orig_size[0], orig_size[1]],
|
|
0 commit comments