@@ -101,13 +101,13 @@ def parse_args():
101
101
default = os .path .join (current_dir , 'checkpoint.pth.tar' ), type = str , metavar = 'PATH' ,
102
102
help = 'path to checkpoint (default: None)' )
103
103
parser .add_argument ('--early_path' ,
104
- default = os .path .join (current_dir , 'checkpoints' , 'checkpoint_earlystop.pth.tar' ), type = str , metavar = 'PATH' ,
104
+ default = os .path .join (current_dir , 'checkpoints' , 'checkpoint_earlystop.pth.tar' ), type = str , metavar = 'PATH' ,
105
105
help = 'path to checkpoint (default: None)' )
106
106
parser .add_argument ('--exp' , type = str ,
107
107
default = current_dir , help = 'path to exp folder' )
108
108
parser .add_argument ('--optimizer' , type = str , metavar = 'OPTIM' ,
109
109
choices = ['Adam' , 'SGD' ], default = 'Adam' , help = 'optimizer_choice (default: Adam)' )
110
- parser .add_argument ('--patience' , type = int , default = 10 , help = 'Earlystopping patience' )
110
+ parser .add_argument ('--patience' , type = int , default = 1 , help = 'Earlystopping patience' )
111
111
parser .add_argument ('--semi_ratio' , type = float , default = 1 , help = 'ratio of the labeled samples' )
112
112
return parser .parse_args (args = [])
113
113
@@ -236,8 +236,10 @@ def sampling_echograms_full(args):
236
236
samplers_train = torch .load (os .path .join (path_to_echograms , 'sampler3_tr.pt' ))
237
237
supervised_count = int (len (samplers_train [0 ]) * args .semi_ratio )
238
238
samplers_supervised = []
239
+ samplers_tr_rest = []
239
240
for samplers in samplers_train :
240
241
samplers_supervised .append (samplers [:supervised_count ])
242
+ samplers_tr_rest .append (samplers [supervised_count :])
241
243
242
244
augmentation = CombineFunctions ([add_noise_img , flip_x_axis_img ])
243
245
data_transform = CombineFunctions ([remove_nan_inf_img , db_with_limits_img ])
@@ -248,7 +250,12 @@ def sampling_echograms_full(args):
248
250
augmentation_function = augmentation ,
249
251
data_transform_function = data_transform )
250
252
251
- return dataset_semi
253
+ dataset_tr_rest = DatasetImg (
254
+ samplers_tr_rest ,
255
+ args .sampler_probs ,
256
+ augmentation_function = augmentation ,
257
+ data_transform_function = data_transform )
258
+ return dataset_semi , dataset_tr_rest
252
259
253
260
def sampling_echograms_for_s3vm (args ):
254
261
path_to_echograms = paths .path_to_echograms ()
@@ -372,7 +379,7 @@ def main(args):
372
379
########################################'''
373
380
374
381
print ('Sample echograms.' )
375
- dataset_semi = sampling_echograms_full (args )
382
+ dataset_semi , dataset_tr_rest = sampling_echograms_full (args )
376
383
377
384
378
385
dataloader_semi = torch .utils .data .DataLoader (dataset_semi ,
@@ -382,6 +389,14 @@ def main(args):
382
389
drop_last = False ,
383
390
pin_memory = True )
384
391
392
+ dataloader_tr_rest = torch .utils .data .DataLoader (dataset_tr_rest ,
393
+ shuffle = False ,
394
+ batch_size = args .batch ,
395
+ num_workers = args .workers ,
396
+ drop_last = False ,
397
+ pin_memory = True )
398
+
399
+
385
400
dataset_test_bal , dataset_test_unbal = sampling_echograms_test (args )
386
401
dataloader_test_bal = torch .utils .data .DataLoader (dataset_test_bal ,
387
402
shuffle = False ,
@@ -566,11 +581,20 @@ def main(args):
566
581
with open (os .path .join (args .exp , 'train_anno_full_%s.pickle' % percentage ), "wb" ) as f :
567
582
pickle .dump (train_anno , f )
568
583
569
- features_train_unanno , input_tensors_train_unanno , labels_train_unanno = compute_features (
570
- dataloader_bg , model , len (dataset_bg_full ), device = device , args = args )
584
+ features_train_unanno , input_tensors_train_unanno , labels_train_unanno = compute_features (dataloader_tr_rest ,
585
+ model ,
586
+ len (dataset_tr_rest ),
587
+ device = device ,
588
+ args = args )
571
589
train_unanno = [features_train_unanno , labels_train_unanno ]
572
- with open (os .path .join (args .exp , 'train_bg_ %s.pickle' % percentage ), "wb" ) as f :
590
+ with open (os .path .join (args .exp , 'train_unanno_full_ %s.pickle' % percentage ), "wb" ) as f :
573
591
pickle .dump (train_unanno , f )
592
+
593
+ features_train_bg , input_tensors_train_bg , labels_train_bg = compute_features (
594
+ dataloader_bg , model , len (dataset_bg_full ), device = device , args = args )
595
+ train_bg = [features_train_bg , labels_train_bg ]
596
+ with open (os .path .join (args .exp , 'train_bg_%s.pickle' % percentage ), "wb" ) as f :
597
+ pickle .dump (train_bg , f )
574
598
'''
575
599
TESTSET
576
600
'''
0 commit comments