Skip to content

Commit

Permalink
Enhance DataLoader configurations in BlueSkyNonGeoDataModule by addin…
Browse files Browse the repository at this point in the history
…g pin_memory, persistent_workers, and prefetch_factor parameters for improved performance during training, validation, and testing.
  • Loading branch information
valhassan committed Nov 28, 2024
1 parent aadc64a commit 60196fc
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions geo_deep_learning/datamodules/imagery_NonGeoDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,29 @@ def train_dataloader(self) -> DataLoader[Any]:

return DataLoader(self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=True)
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
prefetch_factor=2,
shuffle=True)

def val_dataloader(self):
return DataLoader(self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=False)
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
prefetch_factor=2,
shuffle=False)

def test_dataloader(self):
return DataLoader(self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers, shuffle=False)
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
prefetch_factor=2,
shuffle=False)

def _manage_bands(self, image: torch.Tensor) -> torch.Tensor:
"""
Expand Down

0 comments on commit 60196fc

Please sign in to comment.