Skip to content

Commit 8ea8f6a

Browse files
author
Luan Pham
committed
lint
1 parent eb385d5 commit 8ea8f6a

File tree

13 files changed

+264
-205
lines changed

13 files changed

+264
-205
lines changed

.pre-commit-config.yaml

+43-34
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,44 @@
11
repos:
2-
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v2.3.0
4-
hooks:
5-
- id: check-yaml
6-
- id: end-of-file-fixer
7-
- id: trailing-whitespace
8-
- id: detect-aws-credentials
9-
- id: detect-private-key
10-
- id: end-of-file-fixer
11-
- id: check-added-large-files
12-
- repo: https://github.com/ambv/black
13-
rev: 21.5b0
14-
hooks:
15-
- id: black
16-
language_version: python3.9
17-
- repo: https://github.com/pycqa/isort
18-
rev: 5.8.0
19-
hooks:
20-
- id: isort
21-
args: [--profile, black]
22-
- repo: https://gitlab.com/pycqa/flake8
23-
rev: 3.9.1
24-
hooks:
25-
- id: flake8
26-
- repo: https://github.com/myint/autoflake
27-
rev: v1.4
28-
hooks:
29-
- id: autoflake
30-
args: [
31-
"--in-place",
32-
"--remove-unused-variables",
33-
"--remove-all-unused-imports",
34-
"--exclude=tests/*",
35-
]
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.3.0
4+
hooks:
5+
- id: check-yaml
6+
- id: check-toml
7+
- id: end-of-file-fixer
8+
- id: trailing-whitespace
9+
- id: detect-aws-credentials
10+
args: ["--allow-missing-credentials"]
11+
- id: detect-private-key
12+
- id: end-of-file-fixer
13+
- id: check-added-large-files
14+
- repo: https://github.com/ambv/black
15+
rev: 22.3.0
16+
hooks:
17+
- id: black
18+
language_version: python3.7
19+
- repo: https://github.com/pycqa/isort
20+
rev: 5.8.0
21+
hooks:
22+
- id: isort
23+
args: ["--profile", "black"]
24+
- repo: https://gitlab.com/pycqa/flake8
25+
rev: 3.9.1
26+
hooks:
27+
- id: flake8
28+
- repo: https://github.com/myint/autoflake
29+
rev: v1.4
30+
hooks:
31+
- id: autoflake
32+
args:
33+
[
34+
"--in-place",
35+
"--remove-unused-variables",
36+
"--remove-all-unused-imports",
37+
"--ignore-init-module-imports",
38+
"--exclude=tests/*",
39+
]
40+
- repo: https://github.com/pre-commit/mirrors-prettier
41+
rev: v2.7.1
42+
hooks:
43+
- id: prettier
44+
types_or: [markdown, yaml]

README.md

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,29 @@
44

55
Project Board: https://github.com/users/phamquiluan/projects/3/views/1
66

7-
87
# Prepare data
98

109
1. Download the data from here and put it in the `data` dir: https://drive.google.com/drive/folders/1J_z-laBlG14Fps81FVrUJUjesdND_JTx?usp=sharing
1110
2. The image dir path is `$PWD/data/images`
1211

13-
1412
# Dev guide
1513

1614
1. cd into your dir
15+
1716
```bash
1817
# for example
1918
cd text_detection
2019
```
2120

2221
2. create venv
22+
2323
```bash
2424
python3.9 -m venv env
2525
. env/bin/activate
2626
```
2727

2828
3. install requirements
29+
2930
```bash
3031
pip install -r requirements.txt
3132
```
@@ -34,8 +35,6 @@ pip install -r requirements.txt
3435

3536
on your Python file.
3637

37-
38-
3938
# Docker guide
4039

4140
```

main.py

+84-45
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,26 @@ def timeit_wrapper(*args, **kwargs):
1919
result = func(*args, **kwargs)
2020
end_time = time.perf_counter()
2121
total_time = end_time - start_time
22-
print(f'Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
22+
print(f"Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds")
2323
return result
24+
2425
return timeit_wrapper
2526

27+
2628
# load env var
2729
load_dotenv()
28-
TABLE_DETECTION_PORT=env["TABLE_DETECTION_PORT"]
29-
TABLE_RECOGNITION_PORT=env["TABLE_RECOGNITION_PORT"]
30-
TEXT_DETECTION_PORT=env["TEXT_DETECTION_PORT"]
31-
TEXT_RECOGNITION_PORT=env["TEXT_RECOGNITION_PORT"]
30+
TABLE_DETECTION_PORT = env["TABLE_DETECTION_PORT"]
31+
TABLE_RECOGNITION_PORT = env["TABLE_RECOGNITION_PORT"]
32+
TEXT_DETECTION_PORT = env["TEXT_DETECTION_PORT"]
33+
TEXT_RECOGNITION_PORT = env["TEXT_RECOGNITION_PORT"]
3234

3335

3436
class Box(BaseModel):
35-
name : str = "box"
36-
xmin : int
37-
xmax : int
38-
ymin : int
39-
ymax : int
37+
name: str = "box"
38+
xmin: int
39+
xmax: int
40+
ymin: int
41+
ymax: int
4042

4143
@property
4244
def width(self):
@@ -61,38 +63,42 @@ def get_intersection(self, box):
6163
return (xmax - xmin) * (ymax - ymin)
6264
return 0
6365

66+
6467
class Text(Box):
65-
name : str = "text"
66-
ocr : str = ""
68+
name: str = "text"
69+
ocr: str = ""
70+
6771

6872
class Cell(Box):
69-
name : str = "cell"
70-
texts : List[Text] = []
73+
name: str = "cell"
74+
texts: List[Text] = []
7175

7276
def is_valid(self):
7377
return self.width > CELL_MIN_WIDTH and self.height > CELL_MIN_HEIGHT
7478

7579

7680
class Table(Box):
77-
name : str = "table"
78-
cells : List[Cell] = []
81+
name: str = "table"
82+
cells: List[Cell] = []
83+
7984

80-
def read_tables_from_list(input_list : List[Dict]) -> List[Table]:
85+
def read_tables_from_list(input_list: List[Dict]) -> List[Table]:
8186
tables = []
8287
for item in input_list:
8388
if item["name"] != "table":
8489
continue
8590
new_table = Table(
86-
xmin=item["xmin"],
87-
ymin=item["ymin"],
88-
xmax=item["xmax"],
89-
ymax=item["ymax"]
91+
xmin=item["xmin"], ymin=item["ymin"], xmax=item["xmax"], ymax=item["ymax"]
9092
)
91-
new_table.cells = [Cell(xmin=i["xmin"], ymin=i["ymin"], xmax=i["xmax"], ymax=i["ymax"]) for i in item["cells"]]
93+
new_table.cells = [
94+
Cell(xmin=i["xmin"], ymin=i["ymin"], xmax=i["xmax"], ymax=i["ymax"])
95+
for i in item["cells"]
96+
]
9297
tables.append(new_table)
9398
return tables
9499

95-
def read_texts_from_list(input_list : List[Dict]) -> List[Text]:
100+
101+
def read_texts_from_list(input_list: List[Dict]) -> List[Text]:
96102
texts = []
97103
for item in input_list:
98104
if item["name"] != "text":
@@ -103,34 +109,44 @@ def read_texts_from_list(input_list : List[Dict]) -> List[Text]:
103109
ymin=item["ymin"],
104110
xmax=item["xmax"],
105111
ymax=item["ymax"],
106-
ocr=item["ocr"]
112+
ocr=item["ocr"],
107113
)
108114
)
109115
return texts
110116

117+
111118
@timeit
112119
def get_table(image_path):
113120
image_name = os.path.basename(image_path)
114121
url = f"http://localhost:{TABLE_RECOGNITION_PORT}/ai/infer"
115-
files=[
116-
('file',(image_name,open(image_path,'rb'), mimetypes.guess_type(image_path)[0]))
122+
files = [
123+
(
124+
"file",
125+
(image_name, open(image_path, "rb"), mimetypes.guess_type(image_path)[0]),
126+
)
117127
]
118128
response = requests.request("POST", url, files=files)
119129
return response.json()
120130

131+
121132
@timeit
122133
def get_ocr(image_path):
123134
image_name = os.path.basename(image_path)
124135
url = f"http://localhost:{TEXT_RECOGNITION_PORT}/ai/infer"
125-
files=[
126-
('file',(image_name,open(image_path,'rb'), mimetypes.guess_type(image_path)[0]))
136+
files = [
137+
(
138+
"file",
139+
(image_name, open(image_path, "rb"), mimetypes.guess_type(image_path)[0]),
140+
)
127141
]
128142
response = requests.request("POST", url, files=files)
129143
return response.json()
130144

145+
131146
def get_random_color():
132147
return tuple((np.random.random(3) * 153 + 102).astype(np.uint8).tolist())
133148

149+
134150
def show(img, name="disp", width=1000):
135151
"""
136152
name: name of window, should be name of img
@@ -143,25 +159,43 @@ def show(img, name="disp", width=1000):
143159
cv2.destroyAllWindows()
144160

145161

146-
def draw(image, table_list : List[Table]):
162+
def draw(image, table_list: List[Table]):
147163
vis_image = image.copy()
148164

149165
# draw cell
150166
for table in table_list:
151167
for cell in table.cells:
152-
cv2.rectangle(vis_image, (cell.xmin, cell.ymin), (cell.xmax, cell.ymax), get_random_color(), -1)
168+
cv2.rectangle(
169+
vis_image,
170+
(cell.xmin, cell.ymin),
171+
(cell.xmax, cell.ymax),
172+
get_random_color(),
173+
-1,
174+
)
153175

154176
vis_image = vis_image // 2 + image // 2
155177

156178
# draw table
157179
for table in table_list:
158-
cv2.rectangle(vis_image, (table.xmin, table.ymin), (table.xmax, table.ymax), (0, 0, 255), 4)
180+
cv2.rectangle(
181+
vis_image,
182+
(table.xmin, table.ymin),
183+
(table.xmax, table.ymax),
184+
(0, 0, 255),
185+
4,
186+
)
159187

160188
# draw text
161189
for table in table_list:
162190
for cell in table.cells:
163191
for text in cell.texts:
164-
cv2.rectangle(vis_image, (text.xmin, text.ymin), (text.xmax, text.ymax), (255, 0, 0), 2)
192+
cv2.rectangle(
193+
vis_image,
194+
(text.xmin, text.ymin),
195+
(text.xmax, text.ymax),
196+
(255, 0, 0),
197+
2,
198+
)
165199

166200
# put text
167201
for table in table_list:
@@ -172,43 +206,48 @@ def draw(image, table_list : List[Table]):
172206
text.ocr,
173207
(text.xmin, text.ymin),
174208
cv2.FONT_HERSHEY_SIMPLEX,
175-
0.5, (0, 255, 0), 1)
209+
0.5,
210+
(0, 255, 0),
211+
1,
212+
)
176213
return vis_image
177214

178215

179-
def draw_text(image, text_list : List[Text]):
216+
def draw_text(image, text_list: List[Text]):
180217
for text in text_list:
181-
cv2.rectangle(image, (text.xmin, text.ymin), (text.xmax, text.ymax), (255, 0, 0), 2)
218+
cv2.rectangle(
219+
image, (text.xmin, text.ymin), (text.xmax, text.ymax), (255, 0, 0), 2
220+
)
182221
return image
183222

184223

185-
def merge_text_table(tables : List[Table], texts : List[Text]):
224+
def merge_text_table(tables: List[Table], texts: List[Text]):
186225
for table in tables:
187226
in_table_texts = [t for t in texts if t.get_intersection(table) > 0]
188227

189228
for cell in table.cells:
190-
cell.texts = [t for t in in_table_texts if t.get_intersection(cell) / t.area > 0.4]
229+
cell.texts = [
230+
t for t in in_table_texts if t.get_intersection(cell) / t.area > 0.4
231+
]
232+
191233

192234
@timeit
193235
def main():
194-
image_path = '/home/luan/research/Go5-Project/sample.jpg'
236+
image_path = "/home/luan/research/Go5-Project/sample.jpg"
195237

196238
# read table
197-
output : List = get_table(image_path)
198-
tables : List[Table] = read_tables_from_list(output)
239+
output: List = get_table(image_path)
240+
tables: List[Table] = read_tables_from_list(output)
199241

200242
# read text
201-
output : List = get_ocr(image_path)
202-
texts : List[Text] = read_texts_from_list(output)
243+
output: List = get_ocr(image_path)
244+
texts: List[Text] = read_texts_from_list(output)
203245

204246
merge_text_table(tables, texts)
205247

206248
image = cv2.imread(image_path)
207249
show(draw(image, tables))
208250

209251

210-
211-
212-
213252
if __name__ == "__main__":
214253
main()

table_detection/Install.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
## Go into your preferred directory and clone the repo
44

5-
65
```
76
mkdir Project
87
cd Project/
@@ -30,12 +29,14 @@ pip install mmcv==0.4.3
3029
```
3130

3231
## Clone the repo
32+
3333
```
3434
cd Project/Go5-Project/table_detection
3535
git clone https://github.com/DevashishPrasad/CascadeTabNet.git
3636
```
3737

3838
## Download the Pretrained Model
39+
3940
```
4041
gdown "https://drive.google.com/u/0/uc?id=1-mVr4UBicFk3mjUz5tsVPjQ4jzRtiT7V&export=download"
4142
```

0 commit comments

Comments
 (0)