Skip to content

Commit 07a715c

Browse files
committed
Fix slice pad
1 parent 260fc0c commit 07a715c

8 files changed

+165
-16
lines changed

benchmark/recognition.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from surya.ocr import run_ocr, run_recognition
88
from surya.postprocessing.text import draw_text_on_image
99
from surya.settings import settings
10+
from surya.languages import CODE_TO_LANGUAGE, is_arabic
11+
import arabic_reshaper
1012
import os
1113
import datasets
1214
import json
@@ -46,9 +48,12 @@ def main():
4648

4749
image_scores = defaultdict(list)
4850
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
51+
if any(is_arabic(l) for l in lang):
52+
ref_text = [arabic_reshaper.reshape(t) for t in ref_text]
53+
pred["text_lines"] = [arabic_reshaper.reshape(t) for t in pred["text_lines"]]
4954
image_score = overlap_score(pred["text_lines"], ref_text)
5055
for l in lang:
51-
image_scores[l].append(image_score)
56+
image_scores[CODE_TO_LANGUAGE[l]].append(image_score)
5257

5358
image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()}
5459
print(image_avgs)

ocr_text.py

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from surya.ocr import run_ocr
1111
from surya.postprocessing.text import draw_text_on_image
1212
from surya.settings import settings
13+
from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE
1314
import os
1415

1516

@@ -23,7 +24,14 @@ def main():
2324
parser.add_argument("--lang", type=str, help="Language to use for OCR. Comma separate for multiple.", default="en")
2425
args = parser.parse_args()
2526

27+
# Split and validate language codes
2628
langs = args.lang.split(",")
29+
for i in range(len(langs)):
30+
if langs[i] in LANGUAGE_TO_CODE:
31+
langs[i] = LANGUAGE_TO_CODE[langs[i]]
32+
if langs[i] not in CODE_TO_LANGUAGE:
33+
raise ValueError(f"Language code {langs[i]} not found.")
34+
2735
det_processor = load_detection_processor()
2836
det_model = load_detection_model()
2937

poetry.lock

+15-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pymupdf = "^1.23.8"
3434
snakeviz = "^2.2.0"
3535
datasets = "^2.16.1"
3636
rapidfuzz = "^3.6.1"
37+
arabic-reshaper = "^3.0.0"
3738

3839
[tool.poetry.scripts]
3940
surya_detect = "detect_text:main"

surya/input/processing.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,31 @@ def slice_polys_from_image(image: Image.Image, polys):
7575

7676

7777
def slice_and_pad_poly(image: Image.Image, coordinates):
    """Crop the bounding box of a polygon from *image*, blanking pixels
    that fall outside the polygon.

    :param image: source page image (PIL, any mode convertible to RGB).
    :param coordinates: sequence of polygon corners; each corner is an
        (x, y) pair (list or tuple — normalized to tuples for PIL below).
    :return: an RGB image exactly the size of the polygon's bounding box,
        with out-of-polygon pixels set to 0 (black).
    """
    # PIL's polygon drawing requires a list of (x, y) tuples.
    coordinates = [(corner[0], corner[1]) for corner in coordinates]

    # Rasterize the polygon once into a binary (0/1) mask covering the
    # whole image. The same mask serves both for pixel blanking and for
    # the bounding box — the original drew it twice, identically.
    mask_image = Image.new('L', image.size, 0)
    ImageDraw.Draw(mask_image).polygon(coordinates, outline=1, fill=1)

    # (left, top, right, bottom) of the polygon's nonzero extent.
    bbox = mask_image.getbbox()

    # Blank pixels outside the polygon. The mask array is uint8 0/1, so
    # test `mask == 0` for a boolean index: the previous `~mask` computed
    # bitwise-not (255/254) and was applied as *integer* row indexing,
    # which zeroes the wrong rows or raises IndexError on short images.
    mask = np.array(mask_image)
    polygon_image = np.array(image)
    polygon_image[mask == 0] = 0
    polygon_image = Image.fromarray(polygon_image)

    # The output rectangle is exactly bbox-sized, so the cropped polygon
    # pastes at the origin. (The original centering arithmetic reduces to
    # (0, 0) algebraically: bbox[0]+bbox[2] and bbox[2]-bbox[0] always
    # share parity, so the floor divisions cancel.)
    rectangle = Image.new('RGB', (bbox[2] - bbox[0], bbox[3] - bbox[1]), 'white')
    rectangle.paste(polygon_image.crop(bbox), (0, 0))

    return rectangle
105+

surya/languages.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# ISO 639-1 two-letter code -> human-readable language name, for every
# language the recognition model supports. Used to validate user-supplied
# --lang arguments and to label benchmark results.
CODE_TO_LANGUAGE = {
    'af': 'Afrikaans',
    'am': 'Amharic',
    'ar': 'Arabic',
    'as': 'Assamese',
    'az': 'Azerbaijani',
    'be': 'Belarusian',
    'bg': 'Bulgarian',
    'bn': 'Bangla',
    'br': 'Breton',
    'bs': 'Bosnian',
    'ca': 'Catalan',
    'cs': 'Czech',
    'cy': 'Welsh',
    'da': 'Danish',
    'de': 'German',
    'el': 'Greek',
    'en': 'English',
    'eo': 'Esperanto',
    'es': 'Spanish',
    'et': 'Estonian',
    'eu': 'Basque',
    'fa': 'Persian',
    'fi': 'Finnish',
    'fr': 'French',
    'fy': 'Western Frisian',
    'ga': 'Irish',
    'gd': 'Scottish Gaelic',
    'gl': 'Galician',
    'gu': 'Gujarati',
    'ha': 'Hausa',
    'he': 'Hebrew',
    'hi': 'Hindi',
    'hr': 'Croatian',
    'hu': 'Hungarian',
    'hy': 'Armenian',
    'id': 'Indonesian',
    'is': 'Icelandic',
    'it': 'Italian',
    'ja': 'Japanese',
    'jv': 'Javanese',
    'ka': 'Georgian',
    'kk': 'Kazakh',
    'km': 'Khmer',
    'kn': 'Kannada',
    'ko': 'Korean',
    'ku': 'Kurdish',
    'ky': 'Kyrgyz',
    'la': 'Latin',
    'lo': 'Lao',
    'lt': 'Lithuanian',
    'lv': 'Latvian',
    'mg': 'Malagasy',
    'mk': 'Macedonian',
    'ml': 'Malayalam',
    'mn': 'Mongolian',
    'mr': 'Marathi',
    'ms': 'Malay',
    'my': 'Burmese',
    'ne': 'Nepali',
    'nl': 'Dutch',
    'no': 'Norwegian',
    'om': 'Oromo',
    'or': 'Odia',
    'pa': 'Punjabi',
    'pl': 'Polish',
    'ps': 'Pashto',
    'pt': 'Portuguese',
    'ro': 'Romanian',
    'ru': 'Russian',
    'sa': 'Sanskrit',
    'sd': 'Sindhi',
    'si': 'Sinhala',
    'sk': 'Slovak',
    'sl': 'Slovenian',
    'so': 'Somali',
    'sq': 'Albanian',
    'sr': 'Serbian',
    'su': 'Sundanese',
    'sv': 'Swedish',
    'sw': 'Swahili',
    'ta': 'Tamil',
    'te': 'Telugu',
    'th': 'Thai',
    'tl': 'Tagalog',
    'tr': 'Turkish',
    'ug': 'Uyghur',
    'uk': 'Ukrainian',
    'ur': 'Urdu',
    'uz': 'Uzbek',
    'vi': 'Vietnamese',
    'xh': 'Xhosa',
    'yi': 'Yiddish',
    'zh': 'Chinese'
}

# Reverse lookup: language name -> ISO 639-1 code. Safe because the names
# above are unique.
LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()}
98+
99+
100+
def is_arabic(lang_code):
    """Return True if *lang_code* is a language written in Arabic script.

    Arabic-script text needs reshaping (e.g. via arabic_reshaper) before
    glyph-level comparison or rendering.
    """
    arabic_script_codes = {"ar", "fa", "ps", "ug", "ur"}
    return lang_code in arabic_script_codes

surya/ocr.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
5757
slice_map = []
5858
all_slices = []
5959
all_langs = []
60-
for idx, (image, det_pred, lang) in tqdm(enumerate(zip(images, det_predictions, langs)), desc="Slicing images"):
60+
for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)):
6161
slices = slice_polys_from_image(image, det_pred["polygons"])
6262
slice_map.append(len(slices))
6363
all_slices.extend(slices)
@@ -80,4 +80,4 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
8080
"language": lang
8181
})
8282

83-
return predictions_by_image
83+
return predictions_by_image

surya/recognition.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@ def get_batch_size():
1919

2020
def batch_recognition(images: List, languages: List[List[str]], model, processor):
2121
assert all([isinstance(image, Image.Image) for image in images])
22+
assert len(images) == len(languages)
2223
batch_size = get_batch_size()
2324

2425
images = [image.convert("RGB") for image in images]
25-
model_inputs = processor(text=[""] * len(languages), images=images, lang=languages)
2626

2727
output_text = []
28-
for i in tqdm(range(0, len(model_inputs["pixel_values"]), batch_size), desc="Recognizing Text"):
29-
batch_langs = model_inputs["langs"][i:i+batch_size]
30-
batch_pixel_values = model_inputs["pixel_values"][i:i+batch_size]
28+
for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
29+
batch_langs = languages[i:i+batch_size]
30+
batch_images = images[i:i+batch_size]
31+
model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)
32+
33+
batch_pixel_values = model_inputs["pixel_values"]
34+
batch_langs = model_inputs["langs"]
3135
batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
3236

3337
batch_langs = torch.from_numpy(np.array(batch_langs, dtype=np.int64)).to(model.device)

0 commit comments

Comments
 (0)