
Commit b8f00bd

Finalize integration of reading order model
1 parent 26d9952 commit b8f00bd

File tree: 10 files changed (+55 -158 lines)

.github/workflows/tests.yml (+4)

@@ -36,6 +36,10 @@ jobs:
         run: |
           poetry run python benchmark/layout.py --max 5
           poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
+      - name: Run ordering benchmark test
+        run: |
+          poetry run python benchmark/ordering.py --max 5
+          poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering

README.md (+32 -13)
@@ -35,9 +35,9 @@ Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who
 | Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) |
 | Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) |
 | Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) |
-| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | -- |
+| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) |
 | Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) |
-| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | -- |
+| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) |
 
 # Installation

@@ -65,11 +65,11 @@ pip install streamlit
 surya_gui
 ```
 
-Pass the `--math` command line argument to use the math detection model instead of the default model. This will detect math better, but will be worse at everything else.
+Pass the `--math` command line argument to use the math text detection model instead of the default model. This will detect math better, but will be worse at everything else.
 
 ## OCR (text recognition)
 
-You can OCR text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
+This command will write out a json file with the detected text and bboxes:
 
 ```shell
 surya_ocr DATA_PATH --images --langs hi,en
@@ -117,7 +117,7 @@ predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec
 
 ## Text line detection
 
-You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected bboxes.
+This command will write out a json file with the detected bboxes.
 
 ```shell
 surya_detect DATA_PATH --images
@@ -162,7 +162,7 @@ predictions = batch_text_detection([image], model, processor)
 
 ## Layout analysis
 
-You can detect the layout of an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected layout.
+This command will write out a json file with the detected layout.
 
 ```shell
 surya_layout DATA_PATH --images
@@ -209,7 +209,7 @@ layout_predictions = batch_layout_detection([image], model, processor, line_pred
 
 ## Reading order
 
-You can detect the reading order of an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected reading order and layout.
+This command will write out a json file with the detected reading order and layout.
 
 ```shell
 surya_order DATA_PATH --images
@@ -224,15 +224,14 @@ The `results.json` file will contain a json dictionary where the keys are the in
 
 - `bboxes` - detected bounding boxes for text
   - `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
-  - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
-  - `confidence` - the confidence of the model in the detected text (0-1). This is currently not very reliable.
-  - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Text`, `Title`.
+  - `position` - the position in the reading order of the bbox, starting from 0.
+  - `label` - the label for the bbox. See the layout section of the documentation for a list of potential labels.
 - `page` - the page number in the file
 - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
 
 **Performance tips**
 
-Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `280MB` of VRAM, so very high batch sizes are possible. The default is a batch size `32`, which will use about 9GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `4`.
+Setting the `ORDER_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `360MB` of VRAM, so very high batch sizes are possible. The default batch size is `32`, which will use about 11GB of VRAM. Setting it can help on CPU too, depending on your core count - the default CPU batch size is `4`.
 
 ### From python
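To make the output schema above concrete, here is a minimal sketch of consuming `results.json` (an editorial illustration, not code from this commit; the exact nesting of one result dict per page under each input filename is an assumption based on the truncated description above):

```python
import json

# Hypothetical walk of the ordering output described above. Assumes the
# top-level keys are input filenames and each maps to a list of per-page
# dicts with "bboxes", "page", and "image_bbox" fields.
with open("results.json") as f:
    results = json.load(f)

for name, pages in results.items():
    for page in pages:
        # Sort boxes by their predicted reading-order position (0-based)
        for box in sorted(page["bboxes"], key=lambda b: b["position"]):
            print(name, page["page"], box["position"], box["label"], box["bbox"])
```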

@@ -357,6 +356,16 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/
 - Precision - how well the predicted bboxes cover ground truth bboxes
 - Recall - how well ground truth bboxes cover predicted bboxes
 
+## Reading Order
+
+75% mean accuracy, and 0.14 seconds per image on an A6000 GPU. See the methodology notes - this benchmark is not a perfect measure of accuracy, and is more useful as a sanity check.
+
+**Methodology**
+
+I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with both reading order and layout information. I wanted to avoid using a cloud service for the ground truth.
+
+The accuracy is computed by checking whether each pair of layout boxes is in the correct order, then taking the % of pairs that are correct.
+
 ## Running your own benchmarks
 
 You can benchmark the performance of surya on your machine.
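The pairwise metric described in the methodology above is easy to reproduce. A minimal sketch (editorial illustration only; `benchmark/ordering.py` is the authoritative implementation):

```python
from itertools import combinations

def pairwise_order_accuracy(pred_ranks, true_ranks):
    # pred_ranks/true_ranks: reading-order rank of each layout box, aligned by index.
    pairs = list(combinations(range(len(true_ranks)), 2))
    if not pairs:
        return 1.0
    correct = sum(
        (pred_ranks[i] < pred_ranks[j]) == (true_ranks[i] < true_ranks[j])
        for i, j in pairs
    )
    return correct / len(pairs)

# One swapped pair out of three -> 2/3 accuracy
print(pairwise_order_accuracy([0, 2, 1], [0, 1, 2]))  # 0.666...
```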
@@ -403,6 +412,16 @@ python benchmark/layout.py
 - `--debug` will render images with detected text
 - `--results_dir` will let you specify a directory to save results to instead of the default one
 
+**Reading Order**
+
+```
+python benchmark/ordering.py
+```
+
+- `--max` controls how many images to process for the benchmark
+- `--debug` will render images with detected text
+- `--results_dir` will let you specify a directory to save results to instead of the default one
+
 # Training
 
 Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
@@ -411,7 +430,7 @@ Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a m
 
 # Commercial usage
 
-The text detection, layout analysis, and OCR models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.
+All models were trained from scratch, so they're okay for commercial usage. The weights are licensed cc-by-nc-sa-4.0, but I will waive that for any organization under $5M USD in gross revenue in the most recent 12-month period.
 
 If you want to remove the GPL license requirements for inference or use the weights commercially over the revenue limit, please contact me at surya@vikas.sh for dual licensing.

@@ -424,4 +443,4 @@ This work would not have been possible without amazing open source AI work:
 - [transformers](https://github.com/huggingface/transformers) from huggingface
 - [CRAFT](https://github.com/clovaai/CRAFT-pytorch), a great scene text detection model
 
-Thank you to everyone who makes open source AI possible.
+Thank you to everyone who makes open source AI possible.

benchmark/order.py (-111)

This file was deleted.

ocr_app.py (+1 -2)

@@ -70,8 +70,7 @@ def layout_detection(img) -> (Image.Image, LayoutResult):
 def order_detection(img) -> (Image.Image, OrderResult):
     _, layout_pred = layout_detection(img)
     bboxes = [l.bbox for l in layout_pred.bboxes]
-    labels = [l.label for l in layout_pred.bboxes]
-    pred = batch_ordering([img], [bboxes], order_model, order_processor, labels=[labels])[0]
+    pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
     polys = [l.polygon for l in pred.bboxes]
     positions = [str(l.position) for l in pred.bboxes]
     order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)

reading_order.py (+1 -4)

@@ -42,14 +42,11 @@ def main():
     line_predictions = batch_text_detection(images, det_model, det_processor)
     layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
     bboxes = []
-    labels = []
     for layout_pred in layout_predictions:
         bbox = [l.bbox for l in layout_pred.bboxes]
-        label = [l.label for l in layout_pred.bboxes]
         bboxes.append(bbox)
-        labels.append(label)
 
-    order_predictions = batch_ordering(images, bboxes, model, processor, labels=labels)
+    order_predictions = batch_ordering(images, bboxes, model, processor)
     result_path = os.path.join(args.results_dir, folder_name)
     os.makedirs(result_path, exist_ok=True)
scripts/verify_benchmark_scores.py (+8)

@@ -21,6 +21,12 @@ def verify_rec(data):
         raise ValueError("Scores do not meet the required threshold")
 
 
+def verify_order(data):
+    score = data["mean_accuracy"]
+    if score <= 0.9:
+        raise ValueError("Scores do not meet the required threshold")
+
+
 def verify_scores(file_path, bench_type):
     with open(file_path, 'r') as file:
         data = json.load(file)
@@ -31,6 +37,8 @@ def verify_scores(file_path, bench_type):
         verify_rec(data)
     elif bench_type == "layout":
         verify_layout(data)
+    elif bench_type == "ordering":
+        verify_order(data)
     else:
         raise ValueError("Invalid benchmark type")
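For context, `verify_order` only inspects a top-level `mean_accuracy` key, so the ordering benchmark's `results.json` must expose it. A tiny sketch of the expected shape (the value is made up):

```python
# Illustrative shape of results/benchmark/order_bench/results.json as read
# by verify_order above.
data = {"mean_accuracy": 0.96}
assert data["mean_accuracy"] > 0.9  # verify_order raises at <= 0.9
```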

static/images/nyt_order.jpg (2.08 MB)

static/images/textbook_order.jpg (384 KB)

surya/model/ordering/model.py (+5 -2)

@@ -8,7 +8,7 @@
 from surya.settings import settings
 
 
-def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
+def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
     config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
 
     decoder_config = vars(config.decoder)
@@ -24,8 +24,11 @@ def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
     AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder)
     AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel)
 
-    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config)
+    model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
     assert isinstance(model.decoder, MBartOrder)
     assert isinstance(model.encoder, VariableDonutSwinModel)
 
+    model = model.to(device)
+    model = model.eval()
+    print(f"Loading reading order model {checkpoint} on device {device} with dtype {dtype}")
     return model
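A hedged sketch of the new `load_model` signature in use; the defaults come from `surya.settings`, and the explicit device/dtype values here are only examples:

```python
import torch

from surya.model.ordering.model import load_model

# Defaults are settings.TORCH_DEVICE_MODEL / settings.MODEL_DTYPE;
# explicit values shown for illustration.
model = load_model(device="cuda", dtype=torch.float16)
print(model.dtype)  # torch.float16
```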

surya/ordering.py (+4 -26)
@@ -30,16 +30,11 @@ def rank_elements(arr):
     return rank
 
 
-def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor, labels: Optional[List[List[str]]] = None) -> List[OrderResult]:
+def batch_ordering(images: List, bboxes: List[List[List[float]]], model, processor) -> List[OrderResult]:
     assert all([isinstance(image, Image.Image) for image in images])
     assert len(images) == len(bboxes)
     batch_size = get_batch_size()
 
-    if labels is not None:
-        assert len(labels) == len(images)
-        for l, b in zip(labels, bboxes):
-            assert len(l) == len(b)
-
     images = [image.convert("RGB") for image in images]
 
     output_order = []
@@ -78,11 +73,12 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, process
 
         last_tokens = []
         last_token_mask = []
+        min_val = torch.finfo(model.dtype).min
         for j in range(logits.shape[0]):
             label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
             new_logits = logits[j, -1].clone()
-            new_logits[batch_predictions[j]] = -1e9  # Mask out already predicted tokens, we can only predict each token once
-            new_logits[label_count:] = -1e9  # Mask out all logit positions above the number of bboxes
+            new_logits[batch_predictions[j]] = min_val  # Mask out already predicted tokens, we can only predict each token once
+            new_logits[label_count:] = min_val  # Mask out all logit positions above the number of bboxes
             pred = int(torch.argmax(new_logits, dim=-1).item())
 
             # Add one to avoid colliding with the 1000 height/width token for bboxes
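The switch from a hard-coded `-1e9` to `torch.finfo(model.dtype).min` matters now that `load_model` can load the weights in half precision: `-1e9` is outside float16's finite range. A quick editorial illustration:

```python
import torch

# -1e9 overflows float16 (max finite magnitude ~65504) and becomes -inf,
# while torch.finfo gives the most negative finite value for each dtype.
print(torch.tensor(-1e9, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.finfo(torch.float16).min)           # -65504.0
print(torch.finfo(torch.float32).min)           # ~-3.4028e+38
```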
@@ -119,24 +115,6 @@ def batch_ordering(images: List, bboxes: List[List[List[float]]], model, process
             for box_idx in range(len(row_bboxes)):
                 ranks[row_pred[box_idx]] = box_idx
 
-            if labels is not None:
-                # This is to force headers/footers into the proper order
-                row_label = labels[i+j]
-                combined = [[i, bbox, label, rank] for i, (bbox, label, rank) in enumerate(zip(row_bboxes, row_label, ranks))]
-                combined = sorted(combined, key=lambda x: x[3])
-
-                sorted_boxes = ([row for row in combined if row[2] == "Page-header"] +
-                                [row for row in combined if row[2] not in ["Page-header", "Page-footer"]] +
-                                [row for row in combined if row[2] == "Page-footer"])
-
-                # Re-rank after sorting
-                for rank, row in enumerate(sorted_boxes):
-                    row[3] = rank
-
-                sorted_boxes = sorted(sorted_boxes, key=lambda x: x[0])
-                row_bboxes = [row[1] for row in sorted_boxes]
-                ranks = [row[3] for row in sorted_boxes]
-
             order_boxes = []
             for row_bbox, rank in zip(row_bboxes, ranks):
                 order_box = OrderBox(
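Finally, a sketch of calling the simplified API after this commit (the processor import path is assumed to mirror `surya/model/ordering/model.py`, and the image path and bbox values are placeholders):

```python
from PIL import Image

from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor  # assumed module path
from surya.ordering import batch_ordering

image = Image.open("page.png")  # any page image
# One list of layout bboxes in (x1, y1, x2, y2) format per input image
bboxes = [[[0, 0, 500, 50], [0, 60, 500, 400]]]

model = load_model()
processor = load_processor()

# labels= is gone after this commit; headers/footers are no longer re-sorted
order_predictions = batch_ordering([image], bboxes, model, processor)
for box in order_predictions[0].bboxes:
    print(box.position, box.bbox)
```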
