-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain_detector.py
28 lines (22 loc) · 1014 Bytes
/
train_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from argparse import ArgumentParser
from os import remove
from pathlib import Path
from tempfile import TemporaryDirectory

from ultralytics import YOLO
def yaml_quote(value: str) -> str:
    """Return *value* encoded as a single-quoted YAML scalar.

    Inside single quotes YAML performs no backslash processing and the only
    escape is a doubled single quote. Windows paths and apostrophes therefore
    round-trip unchanged.
    """
    return "'" + value.replace("'", "''") + "'"


def dataset_yaml(source_dir: str) -> str:
    """Build the Ultralytics dataset description for a single 'marker' class.

    *source_dir* is written as the dataset root. The training and validation
    images are expected under ``train/images`` and ``valid/images`` relative
    to it.
    """
    return (
        f'path: {yaml_quote(source_dir)}\n'
        "train: 'train/images'\n"
        "val: 'valid/images'\n"
        'names:\n'
        "  0: 'marker'\n"
    )


def main() -> None:
    """Parse CLI arguments and train a YOLO marker detector on *source_dir*."""
    parser = ArgumentParser(description='DeepArUco++ detector trainer.')
    parser.add_argument('source_dir', help='where to find source images')
    parser.add_argument('run_name', help='directory of the resulting model')
    parser.add_argument('--model', '-m', help='base model to train', default='yolov8m')
    args = parser.parse_args()

    # Pin the dataset root to an absolute path. Users give it relative to the
    # CWD, but Ultralytics may resolve a relative `path` against its own
    # datasets directory instead.
    source_dir = str(Path(args.source_dir).resolve())

    # The data YAML lives in a private temp dir. It is cleaned up even if
    # training raises, and it can never clobber a user file in the CWD.
    with TemporaryDirectory() as tmp_dir:
        data_file = Path(tmp_dir) / 'data.yaml'
        # Ultralytics reads data YAML as UTF-8, so write it as UTF-8 too.
        data_file.write_text(dataset_yaml(source_dir), encoding='utf-8')

        model = YOLO(f'models/{args.model}.pt')
        model.train(data=str(data_file),
                    rect=True, iou=0.5,
                    batch=-1,  # let Ultralytics choose the batch size
                    epochs=1000, patience=10,  # stop early after 10 stale epochs
                    cache=True,
                    name=args.run_name)


if __name__ == '__main__':
    main()