Skip to content

Commit c2067a2

Browse files
authored
Merge pull request #8 from VikParuchuri/dev
Allow for non-axis-aligned bboxes
2 parents 3a68240 + 4235436 commit c2067a2

File tree

9 files changed

+146
-26
lines changed

9 files changed

+146
-26
lines changed

README.md

+12-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Surya is named after the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), w
2727
| Presentation | [Image](static/images/pres.png) |
2828
| Scientific Paper | [Image](static/images/paper.png) |
2929
| Scanned Document | [Image](static/images/scanned.png) |
30+
| Scanned Form | [Image](static/images/funsd.png) |
3031

3132
# Installation
3233

@@ -58,7 +59,13 @@ surya_detect DATA_PATH --images
5859
- `--max` specifies the maximum number of pages to process if you don't want to process everything
5960
- `--results_dir` specifies the directory to save results to instead of the default
6061

61-
This has worked with every language I've tried. It will work best with documents, and may not work well with photos or other images. It will also not work well with handwriting.
62+
The `results.json` file will contain these keys for each page of the input document(s):
63+
64+
- `polygons` - polygons for each detected text line (these are more accurate than the bboxes) in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
65+
- `bboxes` - axis-aligned rectangles for each detected text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
66+
- `vertical_lines` - vertical lines detected in the document in (x1, y1, x2, y2) format.
67+
- `horizontal_lines` - horizontal lines detected in the document in (x1, y1, x2, y2) format.
68+
- `page_number` - the page number of the document
6269

6370
**Performance tips**
6471

@@ -102,8 +109,10 @@ If you want to develop surya, you can install it manually:
102109

103110
# Limitations
104111

105-
- This is specialized for document OCR. It will likely not work on photos or other images. It will also not work on handwritten text.
106-
- Does not work well with images that look like ads or other parts of documents that are usually ignored.
112+
- This is specialized for document OCR. It will likely not work on photos or other images.
113+
- It is for printed text, not handwriting.
114+
- The model has been trained to ignore advertisements, so it does not work well on images that look like ads.
115+
- This has worked for every language I've tried, but languages with very different character sets may not work well.
107116

108117
# Benchmarks
109118

benchmark/detection.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from surya.model.segformer import load_model, load_processor
1010
from surya.model.processing import open_pdf, get_page_images
1111
from surya.detection import batch_inference
12-
from surya.postprocessing.heatmap import draw_bboxes_on_image
12+
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
1313
from surya.postprocessing.util import rescale_bbox
1414
from surya.settings import settings
1515
import os
@@ -68,6 +68,7 @@ def main():
6868
page_metrics = collections.OrderedDict()
6969
for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)):
7070
surya_boxes = sb["bboxes"]
71+
surya_polys = sb["polygons"]
7172

7273
surya_metrics = precision_recall(surya_boxes, cb)
7374
tess_metrics = precision_recall(tb, cb)
@@ -78,7 +79,7 @@ def main():
7879
}
7980

8081
if args.debug:
81-
bbox_image = draw_bboxes_on_image(surya_boxes, copy.deepcopy(images[idx]))
82+
bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx]))
8283
bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png"))
8384

8485
mean_metrics = {}

detect_text.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from surya.model.processing import open_pdf, get_page_images
1010
from surya.detection import batch_inference
1111
from surya.postprocessing.affinity import draw_lines_on_image
12-
from surya.postprocessing.heatmap import draw_bboxes_on_image
12+
from surya.postprocessing.heatmap import draw_bboxes_on_image, draw_polys_on_image
1313
from surya.settings import settings
1414
import os
1515
import filetype
@@ -90,7 +90,7 @@ def main():
9090

9191
if args.images:
9292
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
93-
bbox_image = draw_bboxes_on_image(pred["bboxes"], copy.deepcopy(image))
93+
bbox_image = draw_polys_on_image(pred["polygons"], copy.deepcopy(image))
9494
bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png"))
9595

9696
column_image = draw_lines_on_image(pred["vertical_lines"], copy.deepcopy(image))

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "surya-ocr"
3-
version = "0.1.5"
3+
version = "0.1.6"
44
description = "Document OCR models for multilingual text detection and recognition"
55
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
66
readme = "README.md"

static/images/funsd.png

200 KB
Loading

surya/detection.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,13 @@ def batch_inference(images: List, model, processor):
8484
affinity_size = list(reversed(affinity_map.shape))
8585
heatmap_size = list(reversed(heatmap.shape))
8686
bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes[i])
87+
bbox_data = [bbox.model_dump() for bbox in bboxes]
8788
vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes[i])
8889
horizontal_lines = get_horizontal_lines(affinity_map, affinity_size, orig_sizes[i])
8990

9091
results.append({
91-
"bboxes": bboxes,
92+
"bboxes": [bbd["bbox"] for bbd in bbox_data],
93+
"polygons": [bbd["corners"] for bbd in bbox_data],
9294
"vertical_lines": vertical_lines,
9395
"horizontal_lines": horizontal_lines,
9496
"heatmap": heat_img,

surya/postprocessing/heatmap.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
1+
from typing import List
2+
13
import numpy as np
24
import cv2
35
import math
46
from PIL import ImageDraw
57

68
from surya.postprocessing.util import rescale_bbox
9+
from surya.schema import PolygonBox
710
from surya.settings import settings
811

912

10-
def clean_contained_boxes(boxes: List[PolygonBox]):
    """Drop boxes whose axis-aligned bbox lies entirely inside another box's bbox.

    Boxes are compared pairwise on their ``bbox`` values. A box is never
    compared against an entry with identical corners or an identical bbox,
    so exact duplicates do not eliminate each other.

    Returns a new list holding the surviving ``PolygonBox`` objects in their
    original order.
    """
    kept = []
    for candidate in boxes:
        inner = candidate.bbox
        for other in boxes:
            # Skip the candidate itself (and any exact duplicate of it).
            if other.corners == candidate.corners:
                continue

            outer = other.bbox
            if inner == outer:
                continue

            fits_inside = (
                inner[0] >= outer[0]
                and inner[1] >= outer[1]
                and inner[2] <= outer[2]
                and inner[3] <= outer[3]
            )
            if fits_inside:
                # One enclosing box is enough to discard the candidate.
                break
        else:
            # The loop finished without finding an enclosing box.
            kept.append(candidate)
    return kept
2331

2432

@@ -93,23 +101,14 @@ def get_detected_boxes(textmap, text_threshold=settings.DETECTOR_TEXT_THRESHOLD,
93101
textmap = textmap.astype(np.float32)
94102
boxes, labels = detect_boxes(textmap, text_threshold, low_text)
95103
# From point form to box form
96-
boxes = [
97-
[box[0][0], box[0][1], box[1][0], box[2][1]]
98-
for box in boxes
99-
]
100-
101-
# Ensure correct box format
102-
for box in boxes:
103-
if box[0] > box[2]:
104-
box[0], box[2] = box[2], box[0]
105-
if box[1] > box[3]:
106-
box[1], box[3] = box[3], box[1]
104+
boxes = [PolygonBox(corners=box) for box in boxes]
107105
return boxes
108106

109107

110108
def get_and_clean_boxes(textmap, processor_size, image_size):
    """Detect text boxes in ``textmap``, rescale them, and drop contained boxes.

    Each detected box is rescaled in place from ``processor_size`` to
    ``image_size`` before boxes contained in other boxes are filtered out.
    """
    detected = get_detected_boxes(textmap)
    for box in detected:
        # rescale() updates the box's corners in place.
        box.rescale(processor_size, image_size)
    return clean_contained_boxes(detected)
115114

@@ -122,3 +121,14 @@ def draw_bboxes_on_image(bboxes, image):
122121

123122
return image
124123

124+
125+
def draw_polys_on_image(corners, image):
    """Outline every polygon in ``corners`` in red on ``image``.

    ``corners`` is an iterable of polygons, each an iterable of points whose
    first two items are the x and y coordinates. The image is drawn on in
    place and the same object is returned.
    """
    canvas = ImageDraw.Draw(image)

    for polygon in corners:
        # ImageDraw is handed a list of (x, y) tuples.
        vertices = [(point[0], point[1]) for point in polygon]
        canvas.polygon(vertices, outline='red', width=1)

    return image
133+
134+

surya/postprocessing/util.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,22 @@ def rescale_bbox(bbox, processor_size, image_size):
2323
new_bbox[1] = int(new_bbox[1] * height_scaler)
2424
new_bbox[2] = int(new_bbox[2] * width_scaler)
2525
new_bbox[3] = int(new_bbox[3] * height_scaler)
26-
return new_bbox
26+
return new_bbox
27+
28+
29+
def rescale_point(point, processor_size, image_size):
    """Scale an (x, y) point from processor-size coordinates to image-size coordinates.

    ``processor_size`` and ``image_size`` are (width, height) pairs. The input
    point is left untouched; a deep copy with the scaled coordinates truncated
    to ints is returned.
    """
    src_width, src_height = processor_size
    dst_width, dst_height = image_size

    x_scale = dst_width / src_width
    y_scale = dst_height / src_height

    scaled = copy.deepcopy(point)
    scaled[0] = int(scaled[0] * x_scale)
    scaled[1] = int(scaled[1] * y_scale)
    return scaled
41+
42+
43+
def rescale_points(points, processor_size, image_size):
    """Return a list with ``rescale_point`` applied to each point in ``points``."""
    rescaled = []
    for point in points:
        rescaled.append(rescale_point(point, processor_size, image_size))
    return rescaled

surya/schema.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import copy
2+
from typing import List, Tuple
3+
4+
from pydantic import BaseModel, field_validator, computed_field
5+
6+
7+
class PolygonBox(BaseModel):
    """Four-cornered polygon with a derived axis-aligned bounding box.

    ``corners`` holds four ``[x, y]`` points. ``bbox`` is a computed field, so
    it is included next to ``corners`` in the output of ``model_dump()``.
    """

    # Four [x, y] points outlining the detected region.
    corners: List[List[float]]

    @field_validator('corners')
    @classmethod
    def check_elements(cls, v: List[List[float]]) -> List[List[float]]:
        """Require exactly four corners, each with exactly two coordinates.

        Raises:
            ValueError: if the number of corners or coordinates is wrong.
        """
        if len(v) != 4:
            raise ValueError('corner must have 4 elements')

        for corner in v:
            if len(corner) != 2:
                raise ValueError('corner must have 2 elements')
        return v

    @property
    def height(self):
        """Vertical extent of the enclosing axis-aligned box."""
        # Derived from bbox rather than from two adjacent corners: corners 0
        # and 1 can both sit on the top edge, which would give a height near 0.
        box = self.bbox
        return box[3] - box[1]

    @property
    def width(self):
        """Horizontal extent of the enclosing axis-aligned box."""
        box = self.bbox
        return box[2] - box[0]

    @property
    def area(self):
        """Area of the enclosing axis-aligned box (not of the polygon itself)."""
        return self.width * self.height

    @computed_field
    @property
    def bbox(self) -> List[float]:
        """Smallest axis-aligned ``[x1, y1, x2, y2]`` box enclosing all four corners."""
        # Use min/max over every corner so the box still encloses the polygon
        # when it is rotated or its corners arrive in a different order.
        xs = [corner[0] for corner in self.corners]
        ys = [corner[1] for corner in self.corners]
        return [min(xs), min(ys), max(xs), max(ys)]

    def rescale(self, processor_size, image_size):
        """Rescale the corners in place from ``processor_size`` to ``image_size``.

        Both sizes are (width, height) pairs. Scaled coordinates are truncated
        with ``int()``. Nothing is returned; ``self.corners`` is replaced.
        """
        # Point is in x, y format
        page_width, page_height = processor_size

        img_width, img_height = image_size
        width_scaler = img_width / page_width
        height_scaler = img_height / page_height

        # Work on a copy so the previous corner lists are never mutated.
        new_corners = copy.deepcopy(self.corners)
        for corner in new_corners:
            corner[0] = int(corner[0] * width_scaler)
            corner[1] = int(corner[1] * height_scaler)
        self.corners = new_corners
57+
58+
59+
60+
class Bbox(BaseModel):
    """Box stored as a four-element list, read as ``[x1, y1, x2, y2]``."""

    # Width is bbox[2] - bbox[0]; height is bbox[3] - bbox[1].
    bbox: List[float]

    @field_validator('bbox')
    @classmethod
    def check_4_elements(cls, v: List[float]) -> List[float]:
        """Accept the value only when it has exactly four elements.

        Raises:
            ValueError: if the list length is not 4.
        """
        if len(v) == 4:
            return v
        raise ValueError('bbox must have 4 elements')

    @property
    def height(self):
        """Difference between the fourth and second elements of ``bbox``."""
        bottom = self.bbox[3]
        top = self.bbox[1]
        return bottom - top

    @property
    def width(self):
        """Difference between the third and first elements of ``bbox``."""
        right = self.bbox[2]
        left = self.bbox[0]
        return right - left

    @property
    def area(self):
        """Product of ``width`` and ``height``."""
        return self.width * self.height

0 commit comments

Comments
 (0)