
Commit 506df8a

Commit message: update
1 parent ade0a06 commit 506df8a

14 files changed: +241 -76 lines

deepcluster/earlystopping.py

+82
@@ -0,0 +1,82 @@
from typing import List
import copy
import operator
from enum import Enum, auto
import numpy as np

from torch.nn import Module


class StopVariable(Enum):
    LOSS = auto()
    ACCURACY = auto()
    NONE = auto()


class Best(Enum):
    RANKED = auto()
    ALL = auto()

# , StopVariable.LOSS
stopping_args = dict(
    stop_varnames=[StopVariable.ACCURACY],
    patience=100, max_epochs=1000, remember=Best.RANKED)


class EarlyStopping:
    def __init__(
            self, model: Module, stop_varnames: List[StopVariable],
            patience: int = 100, max_epochs: int = 1000, remember: Best = Best.RANKED):
        self.model = model
        self.comp_ops = []
        self.stop_vars = []
        self.best_vals = []
        for stop_varname in stop_varnames:
            if stop_varname is StopVariable.LOSS:
                self.stop_vars.append('loss')
                self.comp_ops.append(operator.le)
                self.best_vals.append(np.inf)
            elif stop_varname is StopVariable.ACCURACY:
                self.stop_vars.append('acc')
                self.comp_ops.append(operator.ge)
                self.best_vals.append(-np.inf)
        self.remember = remember
        self.remembered_vals = copy.copy(self.best_vals)
        self.max_patience = patience
        self.patience = self.max_patience
        self.max_epochs = max_epochs
        self.best_epoch = None
        self.best_state = None

    def check(self, values: List[np.floating], epoch: int) -> bool:
        checks = [self.comp_ops[i](val, self.best_vals[i])
                  for i, val in enumerate(values)]
        if any(checks):
            self.best_vals = np.choose(checks, [self.best_vals, values])
            self.patience = self.max_patience

            comp_remembered = [
                self.comp_ops[i](val, self.remembered_vals[i])
                for i, val in enumerate(values)]
            if self.remember is Best.ALL:
                if all(comp_remembered):
                    self.best_epoch = epoch
                    self.remembered_vals = copy.copy(values)
                    self.best_state = {
                        key: value.cpu() for key, value
                        in self.model.state_dict().items()}
            elif self.remember is Best.RANKED:
                for i, comp in enumerate(comp_remembered):
                    if comp:
                        if not(self.remembered_vals[i] == values[i]):
                            self.best_epoch = epoch
                            self.remembered_vals = copy.copy(values)
                            self.best_state = {
                                key: value.cpu() for key, value
                                in self.model.state_dict().items()}
                            break
                    else:
                        break
        else:
            self.patience -= 1
        return self.patience == 0
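For reference, a small self-contained sketch of how the class above is driven; the toy model and the made-up metric history are illustrative only, not part of this commit:

    import torch.nn as nn
    from earlystopping import EarlyStopping, StopVariable, Best  # with deepcluster/ on sys.path

    model = nn.Linear(4, 2)  # toy stand-in for the real network
    es = EarlyStopping(model, stop_varnames=[StopVariable.ACCURACY, StopVariable.LOSS],
                       patience=2, max_epochs=50, remember=Best.RANKED)

    # (accuracy, loss) per epoch; values passed to check() must follow the order of stop_varnames
    history = [(0.60, 1.00), (0.70, 0.80), (0.70, 0.85), (0.69, 0.90), (0.68, 0.95)]
    for epoch, (acc, loss) in enumerate(history):
        if es.check([acc, loss], epoch):
            print('stopped at epoch', epoch)  # patience (2) runs out after two epochs with no improvement
            break

    print(es.best_epoch, es.remembered_vals)  # 1 [0.7, 0.8]; es.best_state holds that epoch's weights

Patience is reset whenever any tracked variable matches or beats its best value so far and decremented otherwise; once it reaches zero, check() returns True, and the remembered weights can be restored with model.load_state_dict(es.best_state).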

not_in_use/.DS_Store

8 KB
Binary file not shown.

semi/01p/main_echogram_semi_3classes.py renamed to not_in_use/semi_01p/main_echogram_semi_3classes.py

+1 -1
@@ -25,7 +25,7 @@
 import matplotlib.pyplot as plt
 
 current_dir = os.getcwd()
-sys.path.append(os.path.join(current_dir, '..', '..', 'deepcluster'))
+sys.path.append(os.path.join(current_dir, '../../semi', '..', 'deepcluster'))
 
 import paths
 import clustering

semi/20p/main_echogram_semi_3classes.py renamed to not_in_use/semi_20p/main_echogram_semi_3classes.py

+1 -1
@@ -25,7 +25,7 @@
 import matplotlib.pyplot as plt
 
 current_dir = os.getcwd()
-sys.path.append(os.path.join(current_dir, '..', '..', 'deepcluster'))
+sys.path.append(os.path.join(current_dir, '../../semi', '..', 'deepcluster'))
 
 import paths
 import clustering
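Note on the edited sys.path line in both renamed scripts: the new argument list is written differently, but for any given working directory it normalizes (lexically, ignoring symlinks) to the same deepcluster directory as before. A quick check with a hypothetical working directory; the '/repo/...' path is illustrative only:

    import os

    current_dir = '/repo/not_in_use/semi_01p'  # hypothetical value of os.getcwd()
    old = os.path.join(current_dir, '..', '..', 'deepcluster')
    new = os.path.join(current_dir, '../../semi', '..', 'deepcluster')

    print(os.path.normpath(old))  # /repo/deepcluster
    print(os.path.normpath(new))  # /repo/deepcluster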

semi/.DS_Store

0 Bytes
Binary file not shown.

semi/02p/main_echogram_semi_3classes.py

+157 -74
@@ -46,6 +46,7 @@
 from batch.data_transform_functions.db_with_limits import db_with_limits_img
 from batch.combine_functions import CombineFunctions
 from classifier_linearSVC import SimpleClassifier
+from earlystopping import EarlyStopping, stopping_args
 
 def parse_args():
     current_dir = os.getcwd()
@@ -357,6 +358,75 @@ def sampling_echograms_test(args):
 
     return dataset_test_bal, dataset_test_unbal
 
+def produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster):
+    model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
+    model.cluster_layer = None
+    model.category_layer = None
+
+    print('TEST set: Cluster the features')
+    features_te_bal, input_tensors_te_bal, labels_te_bal = compute_features(dataloader_test_bal, model, len(dataset_test_bal),
+                                            device=device, args=args)
+    clustering_loss_te_bal, pca_features_te_bal = deepcluster.cluster(features_te_bal, verbose=args.verbose)
+
+    mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
+    mlp.append(nn.ReLU(inplace=True).to(device))
+    model.classifier = nn.Sequential(*mlp)
+    model.classifier.to(device)
+
+    # nan_location_bal = np.isnan(pca_features_te_bal)
+    # inf_location_bal = np.isinf(pca_features_te_bal)
+    # if (not np.allclose(nan_location_bal, 0)) or (not np.allclose(inf_location_bal, 0)):
+    #     print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_bal), ' Inf count: ',
+    #           np.sum(inf_location_bal))
+    #     print('Skip epoch ', epoch)
+    #     torch.save(pca_features_te_bal, 'te_pca_NaN_%d_bal.pth.tar' % epoch)
+    #     torch.save(features_te_bal, 'te_feature_NaN_%d_bal.pth.tar' % epoch)
+    #     continue
+
+    # save patches per epochs
+    cp_epoch_out_bal = [features_te_bal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_bal,
+                        labels_te_bal]
+    with open(os.path.join(args.exp, 'bal', 'features', 'cp_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
+        pickle.dump(cp_epoch_out_bal, f)
+    with open(os.path.join(args.exp, 'bal', 'pca_features', 'pca_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
+        pickle.dump(pca_features_te_bal, f)
+    return 0
+
+def produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster):
+    model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
+    model.cluster_layer = None
+    model.category_layer = None
+
+    print('TEST set: Cluster the features')
+    features_te_unbal, input_tensors_te_unbal, labels_te_unbal = compute_features(dataloader_test_unbal, model, len(dataset_test_unbal),
+                                            device=device, args=args)
+    clustering_loss_te_unbal, pca_features_te_unbal = deepcluster.cluster(features_te_unbal, verbose=args.verbose)
+
+    mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
+    mlp.append(nn.ReLU(inplace=True).to(device))
+    model.classifier = nn.Sequential(*mlp)
+    model.classifier.to(device)
+
+    # nan_location_unbal = np.isnan(pca_features_te_unbal)
+    # inf_location_unbal = np.isinf(pca_features_te_unbal)
+    # if (not np.allclose(nan_location_unbal, 0)) or (not np.allclose(inf_location_unbal, 0)):
+    #     print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_unbal), ' Inf count: ',
+    #           np.sum(inf_location_unbal))
+    #     print('Skip epoch ', epoch)
+    #     torch.save(pca_features_te_unbal, 'te_pca_NaN_%d_unbal.pth.tar' % epoch)
+    #     torch.save(features_te_unbal, 'te_feature_NaN_%d_unbal.pth.tar' % epoch)
+    #     continue
+
+    # save patches per epochs
+    cp_epoch_out_unbal = [features_te_unbal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_unbal,
+                          labels_te_unbal]
+
+    with open(os.path.join(args.exp, 'unbal', 'features', 'cp_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
+        pickle.dump(cp_epoch_out_unbal, f)
+    with open(os.path.join(args.exp, 'unbal', 'pca_features', 'pca_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
+        pickle.dump(pca_features_te_unbal, f)
+    return 0
+
 def main(args):
     # fix random seeds
     torch.manual_seed(args.seed)
@@ -418,6 +488,16 @@ def main(args):
     model.category_layer = model.category_layer.double()
     model.category_layer.to(device)
 
+    '''
+    ############################
+    ############################
+    # EarlyStopping (test_accuracy_bal, 100)
+    ############################
+    ############################
+    '''
+    early_stopping = EarlyStopping(model, **stopping_args)
+    stop_vars = []
+
     if args.optimizer is 'Adam':
         print('Adam optimizer: conv')
         optimizer_category = torch.optim.Adam(
@@ -531,7 +611,7 @@ def main(args):
     MAIN TRAINING
     #######################
     #######################'''
-    for epoch in range(args.start_epoch, args.epochs):
+    for epoch in range(args.start_epoch, early_stopping.max_epochs):
         end = time.time()
         print('##################### Start training at Epoch %d ################'% epoch)
         model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
@@ -693,92 +773,95 @@ def main(args):
         with open(os.path.join(args.exp, 'loss_collect.pickle'), "wb") as f:
             pickle.dump(loss_collect, f)
 
-        '''
-        ############################
-        ############################
-        # PSEUDO-LABEL GEN: Test set (balanced UA)
-        ############################
-        ############################
-        '''
-        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
-        model.cluster_layer = None
-        model.category_layer = None
-
-        print('TEST set: Cluster the features')
-        features_te_bal, input_tensors_te_bal, labels_te_bal = compute_features(dataloader_test_bal, model, len(dataset_test_bal),
-                                                device=device, args=args)
-        clustering_loss_te_bal, pca_features_te_bal = deepcluster.cluster(features_te_bal, verbose=args.verbose)
-
-        mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
-        mlp.append(nn.ReLU(inplace=True).to(device))
-        model.classifier = nn.Sequential(*mlp)
-        model.classifier.to(device)
-
-        nan_location_bal = np.isnan(pca_features_te_bal)
-        inf_location_bal = np.isinf(pca_features_te_bal)
-        if (not np.allclose(nan_location_bal, 0)) or (not np.allclose(inf_location_bal, 0)):
-            print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_bal), ' Inf count: ',
-                  np.sum(inf_location_bal))
-            print('Skip epoch ', epoch)
-            torch.save(pca_features_te_bal, 'te_pca_NaN_%d_bal.pth.tar' % epoch)
-            torch.save(features_te_bal, 'te_feature_NaN_%d_bal.pth.tar' % epoch)
-            continue
-
-        # save patches per epochs
-        cp_epoch_out_bal = [features_te_bal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_bal,
-                            labels_te_bal]
+        if (epoch % args.save_epoch == 0):
+            out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
+            out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster)
 
+        '''EarlyStopping'''
+        if early_stopping.check(loss_collect[7], epoch):
+            break
 
-        if (epoch % args.save_epoch == 0):
-            with open(os.path.join(args.exp, 'bal', 'features', 'cp_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
-                pickle.dump(cp_epoch_out_bal, f)
-            with open(os.path.join(args.exp, 'bal', 'pca_features', 'pca_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
-                pickle.dump(pca_features_te_bal, f)
+        out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster)
+        out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args,
+                                        deepcluster)
 
 
         '''
         ############################
         ############################
-        # PSEUDO-LABEL GEN: Test set (Unbalanced UA)
+        # PSEUDO-LABEL GEN: Test set (balanced UA)
         ############################
         ############################
         '''
-        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
-        model.cluster_layer = None
-        model.category_layer = None
-
-        print('TEST set: Cluster the features')
-        features_te_unbal, input_tensors_te_unbal, labels_te_unbal = compute_features(dataloader_test_unbal, model, len(dataset_test_unbal),
-                                                device=device, args=args)
-        clustering_loss_te_unbal, pca_features_te_unbal = deepcluster.cluster(features_te_unbal, verbose=args.verbose)
-
-        mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
-        mlp.append(nn.ReLU(inplace=True).to(device))
-        model.classifier = nn.Sequential(*mlp)
-        model.classifier.to(device)
-
-        nan_location_unbal = np.isnan(pca_features_te_unbal)
-        inf_location_unbal = np.isinf(pca_features_te_unbal)
-        if (not np.allclose(nan_location_unbal, 0)) or (not np.allclose(inf_location_unbal, 0)):
-            print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_unbal), ' Inf count: ',
-                  np.sum(inf_location_unbal))
-            print('Skip epoch ', epoch)
-            torch.save(pca_features_te_unbal, 'te_pca_NaN_%d_unbal.pth.tar' % epoch)
-            torch.save(features_te_unbal, 'te_feature_NaN_%d_unbal.pth.tar' % epoch)
-            continue
-
-        # save patches per epochs
-        cp_epoch_out_unbal = [features_te_unbal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_unbal,
-                              labels_te_unbal]
 
+def produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster):
+    model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
+    model.cluster_layer = None
+    model.category_layer = None
 
-        if (epoch % args.save_epoch == 0):
-            with open(os.path.join(args.exp, 'unbal', 'features', 'cp_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
-                pickle.dump(cp_epoch_out_unbal, f)
-            with open(os.path.join(args.exp, 'unbal', 'pca_features', 'pca_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
-                pickle.dump(pca_features_te_unbal, f)
-
+    print('TEST set: Cluster the features')
+    features_te_bal, input_tensors_te_bal, labels_te_bal = compute_features(dataloader_test_bal, model, len(dataset_test_bal),
+                                            device=device, args=args)
+    clustering_loss_te_bal, pca_features_te_bal = deepcluster.cluster(features_te_bal, verbose=args.verbose)
+
+    mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
+    mlp.append(nn.ReLU(inplace=True).to(device))
+    model.classifier = nn.Sequential(*mlp)
+    model.classifier.to(device)
+
+    # nan_location_bal = np.isnan(pca_features_te_bal)
+    # inf_location_bal = np.isinf(pca_features_te_bal)
+    # if (not np.allclose(nan_location_bal, 0)) or (not np.allclose(inf_location_bal, 0)):
+    #     print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_bal), ' Inf count: ',
+    #           np.sum(inf_location_bal))
+    #     print('Skip epoch ', epoch)
+    #     torch.save(pca_features_te_bal, 'te_pca_NaN_%d_bal.pth.tar' % epoch)
+    #     torch.save(features_te_bal, 'te_feature_NaN_%d_bal.pth.tar' % epoch)
+    #     continue
+
+    # save patches per epochs
+    cp_epoch_out_bal = [features_te_bal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_bal,
+                        labels_te_bal]
+    with open(os.path.join(args.exp, 'bal', 'features', 'cp_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
+        pickle.dump(cp_epoch_out_bal, f)
+    with open(os.path.join(args.exp, 'bal', 'pca_features', 'pca_epoch_%d_te_bal.pickle' % epoch), "wb") as f:
+        pickle.dump(pca_features_te_bal, f)
+    return 0
+
+def produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster):
+    model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])  # remove ReLU at classifier [:-1]
+    model.cluster_layer = None
+    model.category_layer = None
 
+    print('TEST set: Cluster the features')
+    features_te_unbal, input_tensors_te_unbal, labels_te_unbal = compute_features(dataloader_test_unbal, model, len(dataset_test_unbal),
+                                            device=device, args=args)
+    clustering_loss_te_unbal, pca_features_te_unbal = deepcluster.cluster(features_te_unbal, verbose=args.verbose)
+
+    mlp = list(model.classifier.children())  # classifier that ends with linear(512 * 128). No ReLU at the end
+    mlp.append(nn.ReLU(inplace=True).to(device))
+    model.classifier = nn.Sequential(*mlp)
+    model.classifier.to(device)
+
+    # nan_location_unbal = np.isnan(pca_features_te_unbal)
+    # inf_location_unbal = np.isinf(pca_features_te_unbal)
+    # if (not np.allclose(nan_location_unbal, 0)) or (not np.allclose(inf_location_unbal, 0)):
+    #     print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_unbal), ' Inf count: ',
+    #           np.sum(inf_location_unbal))
+    #     print('Skip epoch ', epoch)
+    #     torch.save(pca_features_te_unbal, 'te_pca_NaN_%d_unbal.pth.tar' % epoch)
+    #     torch.save(features_te_unbal, 'te_feature_NaN_%d_unbal.pth.tar' % epoch)
+    #     continue
+
+    # save patches per epochs
+    cp_epoch_out_unbal = [features_te_unbal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_unbal,
+                          labels_te_unbal]
+
+    with open(os.path.join(args.exp, 'unbal', 'features', 'cp_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
+        pickle.dump(cp_epoch_out_unbal, f)
+    with open(os.path.join(args.exp, 'unbal', 'pca_features', 'pca_epoch_%d_te_unbal.pickle' % epoch), "wb") as f:
+        pickle.dump(pca_features_te_unbal, f)
+    return 0
 
 if __name__ == '__main__':
     args = parse_args()
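The changes to main() above boil down to this control flow: loop up to early_stopping.max_epochs, dump the test-set clustering every args.save_epoch epochs, stop when patience runs out, and dump once more after the loop. Below is a toy reduction of that pattern; the dummy model, random metric, and no-op save step stand in for the script's real training code, the produce_test_result_bal/_unbal calls, and the loss_collect bookkeeping (loss_collect[7] is assumed to carry the tracked test accuracy):

    import torch
    import torch.nn as nn
    from earlystopping import EarlyStopping, stopping_args  # with deepcluster/ on sys.path

    model = nn.Linear(8, 3)      # placeholder for the real network
    early_stopping = EarlyStopping(model, **stopping_args)
    save_epoch = 10              # stands in for args.save_epoch

    def save_test_results(epoch):
        pass                     # stands in for produce_test_result_bal / produce_test_result_unbal

    for epoch in range(early_stopping.max_epochs):
        test_acc_bal = float(torch.rand(1))  # stands in for the tracked test accuracy

        if epoch % save_epoch == 0:          # periodic feature/PCA dumps
            save_test_results(epoch)

        # one value per entry in stopping_args['stop_varnames'] (accuracy only, by default)
        if early_stopping.check([test_acc_bal], epoch):
            break

    save_test_results(epoch)                 # one final dump after the loop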

sup/.DS_Store

0 Bytes
Binary file not shown.

supclust/.DS_Store

0 Bytes
Binary file not shown.

unbalsemi/.DS_Store

0 Bytes
Binary file not shown.
