Skip to content

Commit 00a41b0

Browse files
author
NC Cullen
authored
Merge pull request #26 from benwu232/master
When using cuda, this method copies the model to GPU before the epoch…
2 parents 5245e4a + 983c37d commit 00a41b0

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchsample/modules/module_trainer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,9 @@ def fit_loader(self,
494494
initializers = InitializerModule(self._initializers)
495495
initializers(self.model)
496496

497+
if cuda_device > -1:
498+
self.model.cuda(cuda_device)
499+
497500
# enter context-manager for progress bar
498501
with TQDM() as pbar:
499502
# create callbacks
@@ -547,7 +550,6 @@ def fit_loader(self,
547550
input_batch = [ins.cuda(cuda_device) for ins in input_batch]
548551
if has_target:
549552
target_batch = [targs.cuda(cuda_device) for targs in target_batch]
550-
self.model.cuda(cuda_device)
551553

552554
# apply input, target, and input+target transforms if necessary
553555
if self._has_input_transform:

0 commit comments

Comments
 (0)