diff --git a/geo_deep_learning/tasks_with_models/segmentation_unetplus.py b/geo_deep_learning/tasks_with_models/segmentation_unetplus.py index 22cd4c38..0647e72b 100644 --- a/geo_deep_learning/tasks_with_models/segmentation_unetplus.py +++ b/geo_deep_learning/tasks_with_models/segmentation_unetplus.py @@ -50,8 +50,7 @@ def __init__(self, if weights_from_checkpoint_path: print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}") checkpoint = torch.load(weights_from_checkpoint_path) - self.load_state_dict(checkpoint['state_dict']) - + self.load_state_dict(checkpoint['state_dict']) self.loss = loss num_classes = num_classes + 1 if num_classes == 1 else num_classes self.iou_metric = MeanIoU(num_classes=num_classes,