Skip to content

Commit ce7d905

Browse files
committed
learn cmd: enable specifying --model-path
1 parent a950343 commit ce7d905

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

whereami/__main__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def get_args_parser():
5454
help='Change the wifi device to use')
5555
learn_parser.add_argument('--num_samples', '-n', type=int,
5656
default=1, help='Number of samples to take')
57+
learn_parser.add_argument('--model_path', '-mp', default=None,
58+
help='The directory of the model / trained data')
5759

5860
rename = subparsers.add_parser('rename')
5961

@@ -77,7 +79,7 @@ def main():
7779
elif args.command == "predict":
7880
print(predict(args.input_path, args.model_path, args.device))
7981
elif args.command == "learn":
80-
learn(args.location, args.num_samples, args.device)
82+
learn(args.location, args.num_samples, args.device, args.model_path)
8183
elif args.command == "crossval":
8284
crossval(path=args.model_path)
8385
elif args.command in ["locations", "ls"]:

whereami/learn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def write_data(label_path, data):
1616
f.write("\n")
1717

1818

19-
def learn(label, n=1, device=""):
20-
path = ensure_whereami_path()
19+
def learn(label, n=1, device="", model_path=None):
20+
path = ensure_whereami_path(model_path)
2121
label_path = get_label_file(path, label + ".txt")
2222
for i in tqdm(range(n)):
2323
if i != 0:

whereami/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def get_whereami_path(path=None):
99
return os.path.expanduser(path)
1010

1111

12-
def ensure_whereami_path():
13-
path = get_whereami_path()
12+
def ensure_whereami_path(path=None):
13+
path = get_whereami_path(path)
1414
if not os.path.exists(path): # pragma: no cover
1515
os.makedirs(path)
1616
return path

0 commit comments

Comments
 (0)