From 8a0a1c4654272e9821e568577cb87c8de84d3d5a Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Thu, 29 Jul 2021 15:25:54 +0200 Subject: [PATCH 1/8] added callbacks --- train.py | 19 ++++- utils/callbacks.py | 183 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 utils/callbacks.py diff --git a/train.py b/train.py index 3f5b5ed1195b..a0acd0a0fc7f 100644 --- a/train.py +++ b/train.py @@ -43,6 +43,8 @@ from utils.metrics import fitness from utils.loggers import Loggers +from utils import callbacks + LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) @@ -52,6 +54,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, + callbacks ): save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ @@ -330,6 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots) + callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots) # end batch ------------------------------------------------------------------------------------------------ @@ -340,6 +344,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP loggers.on_train_epoch_end(epoch) + callbacks.on_train_epoch_end(epoch) + ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not noval or final_epoch: # Calculate mAP @@ -361,6 +367,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if fi > best_fitness: best_fitness = fi loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi) + callbacks.on_val_end(mloss, results, lr, epoch, best_fitness, fi) # Save model if (not nosave) or (final_epoch and not evolve): # if save @@ -378,6 +385,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary torch.save(ckpt, best) del ckpt loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi) + callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- @@ -401,6 +409,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if f.exists(): strip_optimizer(f) # strip optimizers loggers.on_train_end(last, best, plots) + callbacks.on_train_end(last, best, plots) torch.cuda.empty_cache() return results @@ -446,7 +455,11 @@ def parse_opt(known=False): return opt -def main(opt): +def main(opt, callback_handler = None): + + # Define new hook handler if one is not passed in + if not callback_handler: callback_handler = callbacks.Callbacks() + set_logging(RANK) if RANK in [-1, 0]: print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items())) @@ -482,7 +495,7 @@ def main(opt): # Train if not opt.evolve: - train(opt.hyp, opt, device) + train(opt.hyp, opt, device, callback_handler) if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... 
', end=''), dist.destroy_process_group(), print('Done.')]
 
@@ -562,7 +575,7 @@ def main(opt):
                     hyp[k] = round(hyp[k], 5)  # significant digits
 
             # Train mutation
-            results = train(hyp.copy(), opt, device)
+            results = train(hyp.copy(), opt, device, callback_handler)
 
             # Write mutation results
             print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
diff --git a/utils/callbacks.py b/utils/callbacks.py
new file mode 100644
index 000000000000..d02f114021b3
--- /dev/null
+++ b/utils/callbacks.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python
+
+class Callbacks:
+    """
+    Handles all registered callbacks for YOLOv5 Hooks
+    """
+
+    _callbacks = {
+        'on_pretrain_routine_start': [],
+        'on_pretrain_routine_end': [],
+
+        'on_train_start': [],
+        'on_train_end': [],
+        'on_train_epoch_start': [],
+        'on_train_epoch_end': [],
+        'on_train_batch_start': [],
+        'on_train_batch_end': [],
+
+        'on_val_start': [],
+        'on_val_end': [],
+        'on_val_epoch_start': [],
+        'on_val_epoch_end': [],
+        'on_val_batch_start': [],
+        'on_val_batch_end': [],
+
+
+        'on_model_save': [],
+        'optimizer_step': [],
+        'on_before_zero_grad': [],
+        'teardown': [],
+    }
+
+    def __init__(self):
+        return
+
+    def registerAction(self, hook, name, callback):
+        """
+        Register a new action to a callback hook
+
+        Args:
+            hook      The callback hook name to register the action to
+            name      The name of the action
+            callback  The callback to fire
+
+        Returns:
+            (bool) The success state
+        """
+        if hook in self._callbacks:
+            self._callbacks[hook].append({'name': name, 'callback': callback})
+            return True
+        else:
+            return False
+
+    def getRegisteredActions(self, hook=None):
+        """
+        Returns all the registered actions by callback hook
+
+        Args:
+            hook  The name of the hook to check, defaults to all
+        """
+        if hook:
+            return self._callbacks[hook]
+        else:
+            return self._callbacks
+
+    def fireCallbacks(self, register, *args):
+        """
+        Loops through the registered actions and fires all callbacks
+        """
+        for logger in register:
+            logger['callback'](*args)
+
+
+    def on_pretrain_routine_start(self, *args):
+        """
+        Fires all registered callbacks at the start of each pretraining routine
+        """
+        self.fireCallbacks(self._callbacks['on_pretrain_routine_start'], *args)
+
+    def on_pretrain_routine_end(self, *args):
+        """
+        Fires all registered callbacks at the end of each pretraining routine
+        """
+        self.fireCallbacks(self._callbacks['on_pretrain_routine_end'], *args)
+
+    def on_train_start(self, *args):
+        """
+        Fires all registered callbacks at the start of training
+        """
+        self.fireCallbacks(self._callbacks['on_train_start'], *args)
+
+    def on_train_end(self, *args):
+        """
+        Fires all registered callbacks at the end of training
+        """
+        self.fireCallbacks(self._callbacks['on_train_end'], *args)
+
+    def on_train_epoch_start(self, *args):
+        """
+        Fires all registered callbacks at the start of each training epoch
+        """
+        self.fireCallbacks(self._callbacks['on_train_epoch_start'], *args)
+
+    def on_train_epoch_end(self, *args):
+        """
+        Fires all registered callbacks at the end of each training epoch
+        """
+        self.fireCallbacks(self._callbacks['on_train_epoch_end'], *args)
+
+
+    def on_train_batch_start(self, *args):
+        """
+        Fires all registered callbacks at the start of each training batch
+        """
+        self.fireCallbacks(self._callbacks['on_train_batch_start'], *args)
+
+    def on_train_batch_end(self, *args):
+        """
+        Fires all registered callbacks at the end of each training batch
+        """
+        self.fireCallbacks(self._callbacks['on_train_batch_end'], *args)
+
+    def on_val_start(self, *args):
+        """
Fires all registered callbacks at the start of the validation + """ + self.fireCallbacks(self._callbacks['on_val_start'], *args) + + def on_val_end(self, *args): + """ + Fires all registered callbacks at the end of the validation + """ + self.fireCallbacks(self._callbacks['on_val_end'], *args) + + def on_val_epoch_start(self, *args): + """ + Fires all registered callbacks at the start of each validation epoch + """ + self.fireCallbacks(self._callbacks['on_val_epoch_start'], *args) + + def on_val_epoch_end(self, *args): + """ + Fires all registered callbacks at the end of each validation epoch + """ + self.fireCallbacks(self._callbacks['on_val_epoch_end'], *args) + + def on_val_batch_start(self, *args): + """ + Fires all registered callbacks at the start of each validation batch + """ + self.fireCallbacks(self._callbacks['on_val_batch_start'], *args) + + def on_val_batch_end(self, *args): + """ + Fires all registered callbacks at the end of each validation batch + """ + self.fireCallbacks(self._callbacks['on_val_batch_end'], *args) + + def on_model_save(self, *args): + """ + Fires all registered callbacks after each model save + """ + self.fireCallbacks(self._callbacks['on_model_save'], *args) + + def optimizer_step(self, *args): + """ + Fires all registered callbacks on each optimizer step + """ + self.fireCallbacks(self._callbacks['optimizer_step'], *args) + + def on_before_zero_grad(self, *args): + """ + Fires all registered callbacks before zero grad + """ + self.fireCallbacks(self._callbacks['on_before_zero_grad'], *args) + + def teardown(self, *args): + """ + Fires all registered callbacks before teardown + """ + self.fireCallbacks(self._callbacks['teardown'], *args) + + From 83fb93fb5d4c2169734e43b03e3e1eb89396d7f5 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Tue, 3 Aug 2021 10:44:55 +0200 Subject: [PATCH 2/8] added back callback to main --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index dbaf16e9a158..1c84e51a8a56 100644 --- a/train.py +++ b/train.py @@ -457,7 +457,7 @@ def parse_opt(known=False): return opt -def main(opt): +def main(opt, callback_handler): # Checks set_logging(RANK) if RANK in [-1, 0]: From b4434b774f13482b6715ce3bfbda076c352cda6f Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Tue, 3 Aug 2021 11:37:51 +0200 Subject: [PATCH 3/8] added save_dir to callback output --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 1c84e51a8a56..d25b37f059b0 100644 --- a/train.py +++ b/train.py @@ -386,7 +386,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if best_fitness == fi: torch.save(ckpt, best) del ckpt - callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) + callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi, save_dir) # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- From ef0760a988895939534df54f5528337872ba71ed Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Thu, 9 Sep 2021 17:29:54 +0200 Subject: [PATCH 4/8] merged in upstream --- train.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/train.py b/train.py index 626b7c32f821..7c05849dee00 100644 --- a/train.py +++ b/train.py @@ -383,9 +383,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if best_fitness == fi: torch.save(ckpt, best) del ckpt -<<<<<<< HEAD - 
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi, save_dir) -======= callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) # Stop Single-GPU @@ -401,7 +398,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # with torch_distributed_zero_first(RANK): # if stop: # break # must break all DDP ranks ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- @@ -473,11 +469,7 @@ def parse_opt(known=False): return opt -<<<<<<< HEAD -def main(opt, callback_handler): -======= def main(opt, callbacks=Callbacks()): ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 # Checks set_logging(RANK) if RANK in [-1, 0]: @@ -516,11 +508,7 @@ def main(opt, callbacks=Callbacks()): # Train if not opt.evolve: -<<<<<<< HEAD - train(opt.hyp, opt, device, callback_handler) -======= train(opt.hyp, opt, device, callbacks) ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')] @@ -600,11 +588,7 @@ def main(opt, callbacks=Callbacks()): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation -<<<<<<< HEAD - results = train(hyp.copy(), opt, device, callback_handler) -======= results = train(hyp.copy(), opt, device, callbacks) ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 # Write mutation results print_mutation(results, hyp.copy(), save_dir, opt.bucket) From 46bb613eb63e11ac2dabdabdd9243f3ce6f4d067 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Sat, 11 Sep 2021 09:17:29 +0200 Subject: [PATCH 5/8] removed ghost code --- train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/train.py b/train.py index 2ce6e5e4c7ef..e5410eeeba9f 100644 --- a/train.py +++ b/train.py @@ -47,8 +47,6 @@ from utils.loggers import Loggers from utils.callbacks import Callbacks -from utils import callbacks - LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) From 0a379d3c3322415dde4c9ff5f9ce25641e9b5d78 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Thu, 16 Sep 2021 11:30:24 +0200 Subject: [PATCH 6/8] added url check --- export.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/export.py b/export.py index ea7f1ebd0b1f..220ba8f37312 100644 --- a/export.py +++ b/export.py @@ -26,6 +26,7 @@ import sys import time from pathlib import Path +import urllib import torch import torch.nn as nn @@ -244,6 +245,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' include = [x.lower() for x in include] tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports imgsz *= 2 if len(imgsz) == 1 else 1 # expand + + # Fix path if URL [DEV] Copied from general.check_file[299] + if weights.startswith(('http:/', 'https:/')): + url = str(Path(weights)).replace(':/', '://') # Pathlib turns :// -> :/ + weights = Path(urllib.parse.unquote(weights)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth + file = Path(weights) # Load PyTorch model From 43d43e6e102821ad2cfce1e79cc42999d5f2bc97 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 16 Sep 2021 13:20:04 +0200 Subject: [PATCH 7/8] Add url2file() --- utils/general.py | 7 +++++++ 1 file changed, 7 insertions(+) diff 
--git a/utils/general.py b/utils/general.py index 7a80b2ea81bc..dc9a10fe8617 100755 --- a/utils/general.py +++ b/utils/general.py @@ -360,6 +360,13 @@ def check_dataset(data, autodownload=True): return data # dictionary +def url2file(url): + # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt + url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/ + file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth + return file + + def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1): # Multi-threaded file download and unzip function, used in data.yaml for autodownload def download_one(url, dir): From 0833edd74d14814f338728a31e62ca76187b3cc1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 16 Sep 2021 13:20:51 +0200 Subject: [PATCH 8/8] Update file-only --- export.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/export.py b/export.py index 220ba8f37312..4ec3c3e0c711 100644 --- a/export.py +++ b/export.py @@ -26,7 +26,6 @@ import sys import time from pathlib import Path -import urllib import torch import torch.nn as nn @@ -42,7 +41,7 @@ from models.yolo import Detect from utils.activations import SiLU from utils.datasets import LoadImages -from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging +from utils.general import colorstr, check_dataset, check_img_size, check_requirements, file_size, set_logging, url2file from utils.torch_utils import select_device @@ -245,13 +244,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' include = [x.lower() for x in include] tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports imgsz *= 2 if len(imgsz) == 1 else 1 # expand - - # Fix path if URL [DEV] Copied from general.check_file[299] - if weights.startswith(('http:/', 'https:/')): - url = str(Path(weights)).replace(':/', '://') # Pathlib turns :// -> :/ - weights = Path(urllib.parse.unquote(weights)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth - - file = Path(weights) + file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # Load PyTorch model device = select_device(device)
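A minimal usage sketch of the hook API introduced in PATCH 1/8 (illustrative only: hook and method names follow utils/callbacks.py as added above, while the upstream code merged in PATCH 4/8 moves the call sites to callbacks.run('hook_name', ...), so the exact surface may differ in later revisions):

# Sketch only: registering and firing a custom action with the Callbacks class from PATCH 1/8.
from utils.callbacks import Callbacks

def log_epoch(epoch):
    # hypothetical user action, fired at the end of every training epoch
    print(f'finished epoch {epoch}')

cb = Callbacks()
cb.registerAction('on_train_epoch_end', name='log_epoch', callback=log_epoch)

# train.py fires the hook once per epoch via callbacks.on_train_epoch_end(epoch);
# fireCallbacks() then invokes every registered action with the same arguments.
cb.on_train_epoch_end(3)  # prints 'finished epoch 3'

A handler built this way can also be passed straight into training, e.g. main(opt, callback_handler=cb) with the signature added in PATCH 1/8.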
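For the export changes, a quick illustration of the url2file() helper added in PATCH 7/8 and consumed by export.py in PATCH 8/8 (illustrative only: the second URL is hypothetical, the first is taken from the helper's own docstring, and it is assumed utils/general.py already imports urllib for check_file(), which the [DEV] comment in PATCH 6/8 references):

# Sketch only: behaviour of utils.general.url2file() and the weights-resolution line from PATCH 8/8.
from pathlib import Path

from utils.general import url2file

print(url2file('https://url.com/file.txt?auth'))  # -> 'file.txt' (basename only, query string dropped)

weights = 'https://url.com/yolov5s.pt'  # hypothetical remote weights URL
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
print(file)  # -> yolov5s.pt, while a plain local path would pass through unchanged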