-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun_experiments_sequential.py
111 lines (92 loc) · 3.77 KB
/
run_experiments_sequential.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import time
import logging
import sys
import os
import pickle
import numpy as np
import copy
from params import *
def run_experiments(args, save_dir):
    """
    Run every configured NAS algorithm for ``args.trials`` trials and pickle
    each trial's results into ``save_dir``.

    Currently we need to set the search space as an environment variable
    so that data.py knows which Cell class to import; the project imports
    below are therefore deferred until after the variable is set.
    TODO: handle this better by making Cell subclasses for each search space

    Args:
        args: parsed command-line namespace; this function reads
            ``search_space``, ``trials``, ``output_filename``,
            ``save_specs`` and ``algo_params``.
        save_dir: directory that receives one
            ``<output_filename>_<trial index>.pkl`` file per trial.

    Each pickle contains the list
    ``[algorithm_params, metann_params, results, walltimes, run_data]``,
    where ``results``, ``walltimes`` and ``run_data`` hold one entry per
    algorithm, in the order of ``algorithm_params``.
    """
    os.environ['search_space'] = args.search_space
    # deferred imports: they must happen after the env var above is set
    from nas_algorithms import run_nas_algorithm
    from data import Data

    # set up arguments
    trials = args.trials
    out_file = args.output_filename
    save_specs = args.save_specs
    metann_params = meta_neuralnet_params(args.search_space)
    algorithm_params = algo_params(args.algo_params)
    logging.info(algorithm_params)

    # set up search space; 'search_space' and 'dataset' are consumed by Data,
    # and the remaining entries of mp are handed to run_nas_algorithm
    mp = copy.deepcopy(metann_params)
    ss = mp.pop('search_space')
    dataset = mp.pop('dataset')
    search_space = Data(ss, dataset=dataset)

    # per-architecture entries removed from run data before pickling,
    # because they take up space; 'spec' is kept only when requested
    keys_to_drop = ['encoding', 'adj', 'path', 'dist_to_min']
    if not save_specs:
        keys_to_drop.append('spec')

    for i in range(trials):
        results = []
        walltimes = []
        run_data = []
        for algo_param in algorithm_params:
            # run NAS algorithm
            print('\n* Running NAS algorithm: {}'.format(algo_param))
            starttime = time.time()
            algo_result, run_datum = run_nas_algorithm(algo_param, search_space, mp)
            algo_result = np.round(algo_result, 5)

            # remove unnecessary dict entries; pop with a default so a
            # missing key (e.g. no 'spec') does not raise KeyError
            for d in run_datum:
                for key in keys_to_drop:
                    d.pop(key, None)

            # add walltime, results, run_data
            walltimes.append(time.time() - starttime)
            results.append(algo_result)
            run_data.append(run_datum)

        # print and pickle results
        filename = os.path.join(save_dir, '{}_{}.pkl'.format(out_file, i))
        print('\n* Trial summary: (params, results, walltimes)')
        print(algorithm_params)
        print(metann_params)
        print(results)
        print(walltimes)
        print('\n* Saving to file {}'.format(filename))
        # the with-statement closes the file; no explicit close() needed
        with open(filename, 'wb') as f:
            pickle.dump([algorithm_params, metann_params, results, walltimes, run_data], f)
def main(args):
    """
    Create the output directories, configure logging, and run the experiments.

    Results are written to ``<args.save_dir>/<args.algo_params>``. The log
    file ``log.txt`` is written to ``<args.save_dir>``, and log records also
    go to stdout.

    Args:
        args: parsed command-line namespace; this function reads ``save_dir``
            and ``algo_params`` and passes the whole namespace on to
            ``run_experiments``.
    """
    # make save directories; makedirs also creates missing parent
    # directories, and exist_ok avoids a check-then-create race
    save_dir = args.save_dir
    save_path = os.path.join(save_dir, args.algo_params)
    os.makedirs(save_path, exist_ok=True)

    # set up logging
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(args)

    run_experiments(args, save_path)
def str2bool(value):
    """
    Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    argparse with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string. That made ``--save_specs False`` enable spec
    saving; this parser interprets the text instead.

    Args:
        value: the raw argument string, or an already-parsed bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Args for BANANAS experiments')
    parser.add_argument('--trials', type=int, default=500, help='Number of trials')
    parser.add_argument('--search_space', type=str, default='nasbench',
                        help='nasbench or darts')
    parser.add_argument('--algo_params', type=str, default='main_experiments', help='which parameters to use')
    parser.add_argument('--output_filename', type=str, default='round', help='name of output files')
    parser.add_argument('--save_dir', type=str, default='results_output', help='name of save directory')
    # nargs='?' + const=True: '--save_specs', '--save_specs True' and
    # '--save_specs False' all behave as written
    parser.add_argument('--save_specs', type=str2bool, nargs='?', const=True,
                        default=False, help='save the architecture specs')
    args = parser.parse_args()
    main(args)