Skip to content

Commit

Permalink
Refactor SegFormer weight loading with utility function and device ma…
Browse files Browse the repository at this point in the history
…pping
  • Loading branch information
valhassan committed Mar 7, 2025
1 parent eb5c740 commit 7247a27
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions geo_deep_learning/tasks_with_models/segmentation_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchmetrics.wrappers import ClasswiseWrapper
from models.segmentation.segformer import SegFormerSegmentationModel
from tools.script_model import ScriptModel
from tools.utils import denormalization
from tools.utils import denormalization, load_weights_from_checkpoint
from tools.visualization import visualize_prediction

class SegmentationSegformer(LightningModule):
Expand Down Expand Up @@ -49,9 +49,12 @@ def __init__(self,
self.num_classes = num_classes
self.model = SegFormerSegmentationModel(encoder, in_channels, weights, freeze_layers, self.num_classes)
if weights_from_checkpoint_path:
print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}")
checkpoint = torch.load(weights_from_checkpoint_path)
self.load_state_dict(checkpoint['state_dict'])
map_location = self.device
load_parts = kwargs.get('load_parts', None)
load_weights_from_checkpoint(self,
weights_from_checkpoint_path,
load_parts=load_parts,
map_location=map_location,)
self.loss = loss
num_classes = num_classes + 1 if num_classes == 1 else num_classes
self.iou_metric = MeanIoU(num_classes=num_classes,
Expand Down

0 comments on commit 7247a27

Please sign in to comment.