|
46 | 46 | from batch.data_transform_functions.db_with_limits import db_with_limits_img
|
47 | 47 | from batch.combine_functions import CombineFunctions
|
48 | 48 | from classifier_linearSVC import SimpleClassifier
|
| 49 | +from earlystopping import EarlyStopping, stopping_args |
49 | 50 |
|
50 | 51 | def parse_args():
|
51 | 52 | current_dir = os.getcwd()
|
@@ -357,6 +358,75 @@ def sampling_echograms_test(args):
|
357 | 358 |
|
358 | 359 | return dataset_test_bal, dataset_test_unbal
|
359 | 360 |
|
| 361 | +def produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster): |
| 362 | + model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1] |
| 363 | + model.cluster_layer = None |
| 364 | + model.category_layer = None |
| 365 | + |
| 366 | + print('TEST set: Cluster the features') |
| 367 | + features_te_bal, input_tensors_te_bal, labels_te_bal = compute_features(dataloader_test_bal, model, len(dataset_test_bal), |
| 368 | + device=device, args=args) |
| 369 | + clustering_loss_te_bal, pca_features_te_bal = deepcluster.cluster(features_te_bal, verbose=args.verbose) |
| 370 | + |
| 371 | + mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end |
| 372 | + mlp.append(nn.ReLU(inplace=True).to(device)) |
| 373 | + model.classifier = nn.Sequential(*mlp) |
| 374 | + model.classifier.to(device) |
| 375 | + |
| 376 | + # nan_location_bal = np.isnan(pca_features_te_bal) |
| 377 | + # inf_location_bal = np.isinf(pca_features_te_bal) |
| 378 | + # if (not np.allclose(nan_location_bal, 0)) or (not np.allclose(inf_location_bal, 0)): |
| 379 | + # print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_bal), ' Inf count: ', |
| 380 | + # np.sum(inf_location_bal)) |
| 381 | + # print('Skip epoch ', epoch) |
| 382 | + # torch.save(pca_features_te_bal, 'te_pca_NaN_%d_bal.pth.tar' % epoch) |
| 383 | + # torch.save(features_te_bal, 'te_feature_NaN_%d_bal.pth.tar' % epoch) |
| 384 | + # continue |
| 385 | + |
| 386 | + # save patches per epochs |
| 387 | + cp_epoch_out_bal = [features_te_bal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_bal, |
| 388 | + labels_te_bal] |
| 389 | + with open(os.path.join(args.exp, 'bal', 'features', 'cp_epoch_%d_te_bal.pickle' % epoch), "wb") as f: |
| 390 | + pickle.dump(cp_epoch_out_bal, f) |
| 391 | + with open(os.path.join(args.exp, 'bal', 'pca_features', 'pca_epoch_%d_te_bal.pickle' % epoch), "wb") as f: |
| 392 | + pickle.dump(pca_features_te_bal, f) |
| 393 | + return 0 |
| 394 | + |
| 395 | +def produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster): |
| 396 | + model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1] |
| 397 | + model.cluster_layer = None |
| 398 | + model.category_layer = None |
| 399 | + |
| 400 | + print('TEST set: Cluster the features') |
| 401 | + features_te_unbal, input_tensors_te_unbal, labels_te_unbal = compute_features(dataloader_test_unbal, model, len(dataset_test_unbal), |
| 402 | + device=device, args=args) |
| 403 | + clustering_loss_te_unbal, pca_features_te_unbal = deepcluster.cluster(features_te_unbal, verbose=args.verbose) |
| 404 | + |
| 405 | + mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end |
| 406 | + mlp.append(nn.ReLU(inplace=True).to(device)) |
| 407 | + model.classifier = nn.Sequential(*mlp) |
| 408 | + model.classifier.to(device) |
| 409 | + |
| 410 | + # nan_location_unbal = np.isnan(pca_features_te_unbal) |
| 411 | + # inf_location_unbal = np.isinf(pca_features_te_unbal) |
| 412 | + # if (not np.allclose(nan_location_unbal, 0)) or (not np.allclose(inf_location_unbal, 0)): |
| 413 | + # print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_unbal), ' Inf count: ', |
| 414 | + # np.sum(inf_location_unbal)) |
| 415 | + # print('Skip epoch ', epoch) |
| 416 | + # torch.save(pca_features_te_unbal, 'te_pca_NaN_%d_unbal.pth.tar' % epoch) |
| 417 | + # torch.save(features_te_unbal, 'te_feature_NaN_%d_unbal.pth.tar' % epoch) |
| 418 | + # continue |
| 419 | + |
| 420 | + # save patches per epochs |
| 421 | + cp_epoch_out_unbal = [features_te_unbal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_unbal, |
| 422 | + labels_te_unbal] |
| 423 | + |
| 424 | + with open(os.path.join(args.exp, 'unbal', 'features', 'cp_epoch_%d_te_unbal.pickle' % epoch), "wb") as f: |
| 425 | + pickle.dump(cp_epoch_out_unbal, f) |
| 426 | + with open(os.path.join(args.exp, 'unbal', 'pca_features', 'pca_epoch_%d_te_unbal.pickle' % epoch), "wb") as f: |
| 427 | + pickle.dump(pca_features_te_unbal, f) |
| 428 | + return 0 |
| 429 | + |
360 | 430 | def main(args):
|
361 | 431 | # fix random seeds
|
362 | 432 | torch.manual_seed(args.seed)
|
@@ -418,6 +488,16 @@ def main(args):
|
418 | 488 | model.category_layer = model.category_layer.double()
|
419 | 489 | model.category_layer.to(device)
|
420 | 490 |
|
| 491 | + ''' |
| 492 | + ############################ |
| 493 | + ############################ |
| 494 | + # EarlyStopping (test_accuracy_bal, 100) |
| 495 | + ############################ |
| 496 | + ############################ |
| 497 | + ''' |
| 498 | + early_stopping = EarlyStopping(model, **stopping_args) |
| 499 | + stop_vars = [] |
| 500 | + |
421 | 501 | if args.optimizer is 'Adam':
|
422 | 502 | print('Adam optimizer: conv')
|
423 | 503 | optimizer_category = torch.optim.Adam(
|
@@ -531,7 +611,7 @@ def main(args):
|
531 | 611 | MAIN TRAINING
|
532 | 612 | #######################
|
533 | 613 | #######################'''
|
534 |
| - for epoch in range(args.start_epoch, args.epochs): |
| 614 | + for epoch in range(args.start_epoch, early_stopping.max_epochs): |
535 | 615 | end = time.time()
|
536 | 616 | print('##################### Start training at Epoch %d ################'% epoch)
|
537 | 617 | model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1]
|
@@ -693,92 +773,95 @@ def main(args):
|
693 | 773 | with open(os.path.join(args.exp, 'loss_collect.pickle'), "wb") as f:
|
694 | 774 | pickle.dump(loss_collect, f)
|
695 | 775 |
|
696 |
| - ''' |
697 |
| - ############################ |
698 |
| - ############################ |
699 |
| - # PSEUDO-LABEL GEN: Test set (balanced UA) |
700 |
| - ############################ |
701 |
| - ############################ |
702 |
| - ''' |
703 |
| - model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1] |
704 |
| - model.cluster_layer = None |
705 |
| - model.category_layer = None |
706 |
| - |
707 |
| - print('TEST set: Cluster the features') |
708 |
| - features_te_bal, input_tensors_te_bal, labels_te_bal = compute_features(dataloader_test_bal, model, len(dataset_test_bal), |
709 |
| - device=device, args=args) |
710 |
| - clustering_loss_te_bal, pca_features_te_bal = deepcluster.cluster(features_te_bal, verbose=args.verbose) |
711 |
| - |
712 |
| - mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end |
713 |
| - mlp.append(nn.ReLU(inplace=True).to(device)) |
714 |
| - model.classifier = nn.Sequential(*mlp) |
715 |
| - model.classifier.to(device) |
716 |
| - |
717 |
| - nan_location_bal = np.isnan(pca_features_te_bal) |
718 |
| - inf_location_bal = np.isinf(pca_features_te_bal) |
719 |
| - if (not np.allclose(nan_location_bal, 0)) or (not np.allclose(inf_location_bal, 0)): |
720 |
| - print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_bal), ' Inf count: ', |
721 |
| - np.sum(inf_location_bal)) |
722 |
| - print('Skip epoch ', epoch) |
723 |
| - torch.save(pca_features_te_bal, 'te_pca_NaN_%d_bal.pth.tar' % epoch) |
724 |
| - torch.save(features_te_bal, 'te_feature_NaN_%d_bal.pth.tar' % epoch) |
725 |
| - continue |
726 |
| - |
727 |
| - # save patches per epochs |
728 |
| - cp_epoch_out_bal = [features_te_bal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_bal, |
729 |
| - labels_te_bal] |
| 776 | + if (epoch % args.save_epoch == 0): |
| 777 | + out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster) |
| 778 | + out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster) |
730 | 779 |
|
| 780 | + '''EarlyStopping''' |
| 781 | + if early_stopping.check(loss_collect[7], epoch): |
| 782 | + break |
731 | 783 |
|
732 |
| - if (epoch % args.save_epoch == 0): |
733 |
| - with open(os.path.join(args.exp, 'bal', 'features', 'cp_epoch_%d_te_bal.pickle' % epoch), "wb") as f: |
734 |
| - pickle.dump(cp_epoch_out_bal, f) |
735 |
| - with open(os.path.join(args.exp, 'bal', 'pca_features', 'pca_epoch_%d_te_bal.pickle' % epoch), "wb") as f: |
736 |
| - pickle.dump(pca_features_te_bal, f) |
| 784 | + out = produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster) |
| 785 | + out = produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, |
| 786 | + deepcluster) |
737 | 787 |
|
738 | 788 |
|
739 | 789 | '''
|
740 | 790 | ############################
|
741 | 791 | ############################
|
742 |
| - # PSEUDO-LABEL GEN: Test set (Unbalanced UA) |
| 792 | + # PSEUDO-LABEL GEN: Test set (balanced UA) |
743 | 793 | ############################
|
744 | 794 | ############################
|
745 | 795 | '''
|
746 |
| - model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1] |
747 |
| - model.cluster_layer = None |
748 |
| - model.category_layer = None |
749 |
| - |
750 |
| - print('TEST set: Cluster the features') |
751 |
| - features_te_unbal, input_tensors_te_unbal, labels_te_unbal = compute_features(dataloader_test_unbal, model, len(dataset_test_unbal), |
752 |
| - device=device, args=args) |
753 |
| - clustering_loss_te_unbal, pca_features_te_unbal = deepcluster.cluster(features_te_unbal, verbose=args.verbose) |
754 |
| - |
755 |
| - mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end |
756 |
| - mlp.append(nn.ReLU(inplace=True).to(device)) |
757 |
| - model.classifier = nn.Sequential(*mlp) |
758 |
| - model.classifier.to(device) |
759 |
| - |
760 |
| - nan_location_unbal = np.isnan(pca_features_te_unbal) |
761 |
| - inf_location_unbal = np.isinf(pca_features_te_unbal) |
762 |
| - if (not np.allclose(nan_location_unbal, 0)) or (not np.allclose(inf_location_unbal, 0)): |
763 |
| - print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_unbal), ' Inf count: ', |
764 |
| - np.sum(inf_location_unbal)) |
765 |
| - print('Skip epoch ', epoch) |
766 |
| - torch.save(pca_features_te_unbal, 'te_pca_NaN_%d_unbal.pth.tar' % epoch) |
767 |
| - torch.save(features_te_unbal, 'te_feature_NaN_%d_unbal.pth.tar' % epoch) |
768 |
| - continue |
769 |
| - |
770 |
| - # save patches per epochs |
771 |
| - cp_epoch_out_unbal = [features_te_unbal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_unbal, |
772 |
| - labels_te_unbal] |
773 | 796 |
|
| 797 | +def produce_test_result_bal(epoch, model, dataloader_test_bal, dataset_test_bal, device, args, deepcluster): |
| 798 | + model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1] |
| 799 | + model.cluster_layer = None |
| 800 | + model.category_layer = None |
774 | 801 |
|
775 |
| - if (epoch % args.save_epoch == 0): |
776 |
| - with open(os.path.join(args.exp, 'unbal', 'features', 'cp_epoch_%d_te_unbal.pickle' % epoch), "wb") as f: |
777 |
| - pickle.dump(cp_epoch_out_unbal, f) |
778 |
| - with open(os.path.join(args.exp, 'unbal', 'pca_features', 'pca_epoch_%d_te_unbal.pickle' % epoch), "wb") as f: |
779 |
| - pickle.dump(pca_features_te_unbal, f) |
780 |
| - |
| 802 | + print('TEST set: Cluster the features') |
| 803 | + features_te_bal, input_tensors_te_bal, labels_te_bal = compute_features(dataloader_test_bal, model, len(dataset_test_bal), |
| 804 | + device=device, args=args) |
| 805 | + clustering_loss_te_bal, pca_features_te_bal = deepcluster.cluster(features_te_bal, verbose=args.verbose) |
| 806 | + |
| 807 | + mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end |
| 808 | + mlp.append(nn.ReLU(inplace=True).to(device)) |
| 809 | + model.classifier = nn.Sequential(*mlp) |
| 810 | + model.classifier.to(device) |
| 811 | + |
| 812 | + # nan_location_bal = np.isnan(pca_features_te_bal) |
| 813 | + # inf_location_bal = np.isinf(pca_features_te_bal) |
| 814 | + # if (not np.allclose(nan_location_bal, 0)) or (not np.allclose(inf_location_bal, 0)): |
| 815 | + # print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_bal), ' Inf count: ', |
| 816 | + # np.sum(inf_location_bal)) |
| 817 | + # print('Skip epoch ', epoch) |
| 818 | + # torch.save(pca_features_te_bal, 'te_pca_NaN_%d_bal.pth.tar' % epoch) |
| 819 | + # torch.save(features_te_bal, 'te_feature_NaN_%d_bal.pth.tar' % epoch) |
| 820 | + # continue |
| 821 | + |
| 822 | + # save patches per epochs |
| 823 | + cp_epoch_out_bal = [features_te_bal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_bal, |
| 824 | + labels_te_bal] |
| 825 | + with open(os.path.join(args.exp, 'bal', 'features', 'cp_epoch_%d_te_bal.pickle' % epoch), "wb") as f: |
| 826 | + pickle.dump(cp_epoch_out_bal, f) |
| 827 | + with open(os.path.join(args.exp, 'bal', 'pca_features', 'pca_epoch_%d_te_bal.pickle' % epoch), "wb") as f: |
| 828 | + pickle.dump(pca_features_te_bal, f) |
| 829 | + return 0 |
| 830 | + |
| 831 | +def produce_test_result_unbal(epoch, model, dataloader_test_unbal, dataset_test_unbal, device, args, deepcluster): |
| 832 | + model.classifier = nn.Sequential(*list(model.classifier.children())[:-1]) # remove ReLU at classifier [:-1] |
| 833 | + model.cluster_layer = None |
| 834 | + model.category_layer = None |
781 | 835 |
|
| 836 | + print('TEST set: Cluster the features') |
| 837 | + features_te_unbal, input_tensors_te_unbal, labels_te_unbal = compute_features(dataloader_test_unbal, model, len(dataset_test_unbal), |
| 838 | + device=device, args=args) |
| 839 | + clustering_loss_te_unbal, pca_features_te_unbal = deepcluster.cluster(features_te_unbal, verbose=args.verbose) |
| 840 | + |
| 841 | + mlp = list(model.classifier.children()) # classifier that ends with linear(512 * 128). No ReLU at the end |
| 842 | + mlp.append(nn.ReLU(inplace=True).to(device)) |
| 843 | + model.classifier = nn.Sequential(*mlp) |
| 844 | + model.classifier.to(device) |
| 845 | + |
| 846 | + # nan_location_unbal = np.isnan(pca_features_te_unbal) |
| 847 | + # inf_location_unbal = np.isinf(pca_features_te_unbal) |
| 848 | + # if (not np.allclose(nan_location_unbal, 0)) or (not np.allclose(inf_location_unbal, 0)): |
| 849 | + # print('PCA: Feature NaN or Inf found. Nan count: ', np.sum(nan_location_unbal), ' Inf count: ', |
| 850 | + # np.sum(inf_location_unbal)) |
| 851 | + # print('Skip epoch ', epoch) |
| 852 | + # torch.save(pca_features_te_unbal, 'te_pca_NaN_%d_unbal.pth.tar' % epoch) |
| 853 | + # torch.save(features_te_unbal, 'te_feature_NaN_%d_unbal.pth.tar' % epoch) |
| 854 | + # continue |
| 855 | + |
| 856 | + # save patches per epochs |
| 857 | + cp_epoch_out_unbal = [features_te_unbal, deepcluster.images_lists, deepcluster.images_dist_lists, input_tensors_te_unbal, |
| 858 | + labels_te_unbal] |
| 859 | + |
| 860 | + with open(os.path.join(args.exp, 'unbal', 'features', 'cp_epoch_%d_te_unbal.pickle' % epoch), "wb") as f: |
| 861 | + pickle.dump(cp_epoch_out_unbal, f) |
| 862 | + with open(os.path.join(args.exp, 'unbal', 'pca_features', 'pca_epoch_%d_te_unbal.pickle' % epoch), "wb") as f: |
| 863 | + pickle.dump(pca_features_te_unbal, f) |
| 864 | + return 0 |
782 | 865 |
|
783 | 866 | if __name__ == '__main__':
|
784 | 867 | args = parse_args()
|
|
0 commit comments