1
+ import math
2
+ from typing import List , Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ from surya .schema import LayoutResult
9
+
10
+ SLICES_TYPE = Tuple [List [Image .Image ], List [Tuple [int , int , int ]]]
11
+
12
+
13
+ class ImageSlicer :
14
+ merge_tolerance = .05
15
+
16
+ def __init__ (self , slice_min_dims , max_slices = 4 ):
17
+ self .slice_min_dims = slice_min_dims
18
+ self .max_slices = max_slices
19
+
20
+ def slice (self , images : List [Image .Image ]) -> SLICES_TYPE :
21
+ all_slices = []
22
+ all_positions = []
23
+
24
+ for idx , image in enumerate (images ):
25
+ if (image .size [0 ] > self .slice_min_dims ["width" ] or
26
+ image .size [1 ] > self .slice_min_dims ["height" ]):
27
+ img_slices , positions = self ._slice_image (image , idx )
28
+ all_slices .extend (img_slices )
29
+ all_positions .extend (positions )
30
+ else :
31
+ all_slices .append (image )
32
+ all_positions .append ((idx , 0 , 0 ))
33
+
34
+ return all_slices , all_positions
35
+
36
+ def slice_count (self , image : Image .Image ) -> int :
37
+ width , height = image .size
38
+ if width > height :
39
+ slice_size = self ._calculate_slice_size (width , "width" )
40
+ return math .ceil (width / slice_size )
41
+ else :
42
+ slice_size = self ._calculate_slice_size (height , "height" )
43
+ return math .ceil (height / slice_size )
44
+
45
+ def _calculate_slice_size (self , dimension : int , dim_type : str ) -> int :
46
+ min_size = self .slice_min_dims [dim_type ]
47
+ return max (min_size , (dimension // self .max_slices + 1 ))
48
+
49
+ def _slice_image (self , image : Image .Image , idx : int ) -> SLICES_TYPE :
50
+ width , height = image .size
51
+ slices = []
52
+ positions = []
53
+
54
+ if width > height :
55
+ slice_size = self ._calculate_slice_size (width , "width" )
56
+ for i , x in enumerate (range (0 , width , slice_size )):
57
+ slice_end = min (x + slice_size , width )
58
+ slices .append (image .crop ((x , 0 , slice_end , height )))
59
+ positions .append ((idx , i , 0 ))
60
+ else :
61
+ slice_size = self ._calculate_slice_size (height , "height" )
62
+ for i , y in enumerate (range (0 , height , slice_size )):
63
+ slice_end = min (y + slice_size , height )
64
+ slices .append (image .crop ((0 , y , width , slice_end )))
65
+ positions .append ((idx , 0 , i ))
66
+
67
+ return slices , positions
68
+
69
+ def join (self , results : List [LayoutResult ], tile_positions : List [Tuple [int , int , int ]]) -> List [LayoutResult ]:
70
+ new_results = []
71
+ current_result = None
72
+ for idx , (result , tile_position ) in enumerate (zip (results , tile_positions )):
73
+ image_idx , tile_x , tile_y = tile_position
74
+ if idx == 0 or image_idx != tile_positions [idx - 1 ][0 ]:
75
+ if current_result is not None :
76
+ new_results .append (current_result )
77
+ current_result = result
78
+ else :
79
+ merge_dir = "width" if tile_x > 0 else "height"
80
+ current_result = self .merge_results (current_result , result , merge_dir = merge_dir )
81
+ if current_result is not None :
82
+ new_results .append (current_result )
83
+ return new_results
84
+
85
+
86
+ def merge_results (self , res1 : LayoutResult , res2 : LayoutResult , merge_dir = "width" ) -> LayoutResult :
87
+ new_image_bbox = res1 .image_bbox .copy ()
88
+ to_remove_idxs = set ()
89
+ if merge_dir == "width" :
90
+ new_image_bbox [2 ] += res2 .image_bbox [2 ]
91
+ max_position = max ([box .position for box in res1 .bboxes ])
92
+ for i , box2 in enumerate (res2 .bboxes ):
93
+ box2 .shift (x_shift = res1 .image_bbox [2 ])
94
+ box2 .position += max_position
95
+ for j , box1 in enumerate (res1 .bboxes ):
96
+ if all ([
97
+ box1 .intersection_area (box2 , x_margin = .1 ) > self .merge_tolerance ,
98
+ (
99
+ box1 .y_overlap (box2 , y_margin = .1 ) > box1 .height // 2 or
100
+ box2 .y_overlap (box1 , y_margin = .1 ) > box2 .height // 2
101
+ ),
102
+ box1 .label == box2 .label
103
+ ]):
104
+ box1 .merge (box2 )
105
+ to_remove_idxs .add (i )
106
+
107
+ elif merge_dir == "height" :
108
+ new_image_bbox [3 ] += res2 .image_bbox [3 ]
109
+ max_position = max ([box .position for box in res1 .bboxes ])
110
+ for i , box2 in enumerate (res2 .bboxes ):
111
+ box2 .shift (y_shift = res1 .image_bbox [3 ])
112
+ box2 .position += max_position
113
+ for j , box1 in enumerate (res1 .bboxes ):
114
+ if all ([
115
+ box1 .intersection_area (box2 , y_margin = .1 ) > self .merge_tolerance ,
116
+ (
117
+ box1 .x_overlap (box2 , x_margin = .1 ) > box1 .width // 2 or
118
+ box2 .x_overlap (box1 , x_margin = .1 ) > box2 .width // 2
119
+ ),
120
+ box1 .label == box2 .label
121
+ ]):
122
+ box1 .merge (box2 )
123
+ to_remove_idxs .add (i )
124
+
125
+ new_result = LayoutResult (
126
+ image_bbox = new_image_bbox ,
127
+ bboxes = res1 .bboxes + [b for i , b in enumerate (res2 .bboxes ) if i not in to_remove_idxs ]
128
+ )
129
+ return new_result
0 commit comments