Skip to content

Commit 7cb659c

Browse files
committed
Update main_echogram_supervised_3classes.py
1 parent 4bf06cc commit 7cb659c

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

sup/100p/main_echogram_supervised_3classes.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,13 @@ def parse_args():
101101
default=os.path.join(current_dir, 'checkpoint.pth.tar'), type=str, metavar='PATH',
102102
help='path to checkpoint (default: None)')
103103
parser.add_argument('--early_path',
104-
default= os.path.join(current_dir, 'checkpoints', 'checkpoint_earlystop.pth.tar'), type=str, metavar='PATH',
104+
default=os.path.join(current_dir, 'checkpoints', 'checkpoint_earlystop.pth.tar'), type=str, metavar='PATH',
105105
help='path to checkpoint (default: None)')
106106
parser.add_argument('--exp', type=str,
107107
default=current_dir, help='path to exp folder')
108108
parser.add_argument('--optimizer', type=str, metavar='OPTIM',
109109
choices=['Adam', 'SGD'], default='Adam', help='optimizer_choice (default: Adam)')
110-
parser.add_argument('--patience', type=int, default=10, help='Earlystopping patience')
110+
parser.add_argument('--patience', type=int, default=1, help='Earlystopping patience')
111111
parser.add_argument('--semi_ratio', type=float, default=1, help='ratio of the labeled samples')
112112
return parser.parse_args(args=[])
113113

@@ -236,8 +236,10 @@ def sampling_echograms_full(args):
236236
samplers_train = torch.load(os.path.join(path_to_echograms, 'sampler3_tr.pt'))
237237
supervised_count = int(len(samplers_train[0]) * args.semi_ratio)
238238
samplers_supervised = []
239+
samplers_tr_rest = []
239240
for samplers in samplers_train:
240241
samplers_supervised.append(samplers[:supervised_count])
242+
samplers_tr_rest.append(samplers[supervised_count:])
241243

242244
augmentation = CombineFunctions([add_noise_img, flip_x_axis_img])
243245
data_transform = CombineFunctions([remove_nan_inf_img, db_with_limits_img])
@@ -248,7 +250,12 @@ def sampling_echograms_full(args):
248250
augmentation_function=augmentation,
249251
data_transform_function=data_transform)
250252

251-
return dataset_semi
253+
dataset_tr_rest = DatasetImg(
254+
samplers_tr_rest,
255+
args.sampler_probs,
256+
augmentation_function=augmentation,
257+
data_transform_function=data_transform)
258+
return dataset_semi, dataset_tr_rest
252259

253260
def sampling_echograms_for_s3vm(args):
254261
path_to_echograms = paths.path_to_echograms()
@@ -372,7 +379,7 @@ def main(args):
372379
########################################'''
373380

374381
print('Sample echograms.')
375-
dataset_semi = sampling_echograms_full(args)
382+
dataset_semi, dataset_tr_rest = sampling_echograms_full(args)
376383

377384

378385
dataloader_semi = torch.utils.data.DataLoader(dataset_semi,
@@ -382,6 +389,14 @@ def main(args):
382389
drop_last=False,
383390
pin_memory=True)
384391

392+
dataloader_tr_rest = torch.utils.data.DataLoader(dataset_tr_rest,
393+
shuffle=False,
394+
batch_size=args.batch,
395+
num_workers=args.workers,
396+
drop_last=False,
397+
pin_memory=True)
398+
399+
385400
dataset_test_bal, dataset_test_unbal = sampling_echograms_test(args)
386401
dataloader_test_bal = torch.utils.data.DataLoader(dataset_test_bal,
387402
shuffle=False,
@@ -566,11 +581,20 @@ def main(args):
566581
with open(os.path.join(args.exp, 'train_anno_full_%s.pickle' % percentage), "wb") as f:
567582
pickle.dump(train_anno, f)
568583

569-
features_train_unanno, input_tensors_train_unanno, labels_train_unanno = compute_features(
570-
dataloader_bg, model, len(dataset_bg_full), device=device, args=args)
584+
features_train_unanno, input_tensors_train_unanno, labels_train_unanno = compute_features(dataloader_tr_rest,
585+
model,
586+
len(dataset_tr_rest),
587+
device=device,
588+
args=args)
571589
train_unanno = [features_train_unanno, labels_train_unanno]
572-
with open(os.path.join(args.exp, 'train_bg_%s.pickle' % percentage), "wb") as f:
590+
with open(os.path.join(args.exp, 'train_unanno_full_%s.pickle' % percentage), "wb") as f:
573591
pickle.dump(train_unanno, f)
592+
593+
features_train_bg, input_tensors_train_bg, labels_train_bg = compute_features(
594+
dataloader_bg, model, len(dataset_bg_full), device=device, args=args)
595+
train_bg = [features_train_bg, labels_train_bg]
596+
with open(os.path.join(args.exp, 'train_bg_%s.pickle' % percentage), "wb") as f:
597+
pickle.dump(train_bg, f)
574598
'''
575599
TESTSET
576600
'''

0 commit comments

Comments
 (0)