|
1 | 1 | from copy import deepcopy
|
2 |
| -from typing import List |
| 2 | +from typing import List, Optional |
3 | 3 | import torch
|
4 | 4 | from PIL import Image
|
5 | 5 |
|
@@ -30,11 +30,16 @@ def rank_elements(arr):
|
30 | 30 | return rank
|
31 | 31 |
|
32 | 32 |
|
33 |
| -def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]: |
| 33 | +def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor, labels: Optional[List[List[str]]] = None) -> List[OrderResult]: |
34 | 34 | assert all([isinstance(image, Image.Image) for image in images])
|
35 | 35 | assert len(images) == len(bboxes)
|
36 | 36 | batch_size = get_batch_size()
|
37 | 37 |
|
| 38 | + if labels is not None: |
| 39 | + assert len(labels) == len(images) |
| 40 | + for l, b in zip(labels, bboxes): |
| 41 | + assert len(l) == len(b) |
| 42 | + |
38 | 43 | images = [image.convert("RGB") for image in images]
|
39 | 44 |
|
40 | 45 | output_order = []
|
@@ -64,9 +69,29 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, process
|
64 | 69 | for j in range(logits.shape[0]):
|
65 | 70 | row_logits = logits[j].tolist()
|
66 | 71 | row_bboxes = bboxes[i+j]
|
| 72 | + assert len(row_logits) == len(row_bboxes), "Mismatch between logits and bboxes." |
| 73 | + |
67 | 74 | orig_size = orig_sizes[j]
|
68 | 75 | ranks = rank_elements(row_logits)
|
69 | 76 |
|
| 77 | + if labels is not None: |
| 78 | + # This is to force headers/footers into the proper order |
| 79 | + row_label = labels[i+j] |
| 80 | + combined = [[i, bbox, label, rank] for i, (bbox, label, rank) in enumerate(zip(row_bboxes, row_label, ranks))] |
| 81 | + combined = sorted(combined, key=lambda x: x[3]) |
| 82 | + |
| 83 | + sorted_boxes = ([row for row in combined if row[2] == "Page-header"] + |
| 84 | + [row for row in combined if row[2] not in ["Page-header", "Page-footer"]] + |
| 85 | + [row for row in combined if row[2] == "Page-footer"]) |
| 86 | + |
| 87 | + # Re-rank after sorting |
| 88 | + for rank, row in enumerate(sorted_boxes): |
| 89 | + row[3] = rank |
| 90 | + |
| 91 | + sorted_boxes = sorted(sorted_boxes, key=lambda x: x[0]) |
| 92 | + row_bboxes = [row[1] for row in sorted_boxes] |
| 93 | + ranks = [row[3] for row in sorted_boxes] |
| 94 | + |
70 | 95 | order_boxes = []
|
71 | 96 | for row_bbox, rank in zip(row_bboxes, ranks):
|
72 | 97 | order_box = OrderBox(
|
|
0 commit comments