Skip to content

Commit

Permalink
Update CLI arg handling
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryComer committed Apr 26, 2024
1 parent 9c91667 commit 52c3d91
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 20 deletions.
7 changes: 5 additions & 2 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def _add_arguments_common(parser):
help="Model name for well-known models",
)

add_arguments(parser)


def add_arguments(parser):
# TODO: Refactor this so that only common options are here
Expand Down Expand Up @@ -136,6 +138,7 @@ def add_arguments(parser):
)
parser.add_argument(
"--compile",
default=False,
action="store_true",
help="Whether to compile the model with torch.compile",
)
Expand Down Expand Up @@ -301,10 +304,10 @@ def add_arguments(parser):


def arg_init(args):
if Path(args.quantize).is_file():
if hasattr(args, 'quantize') and Path(args.quantize).is_file():
with open(args.quantize, "r") as f:
args.quantize = json.loads(f.read())

if args.seed:
if hasattr(args, 'seed') and args.seed:
torch.manual_seed(args.seed)
return args
19 changes: 1 addition & 18 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,14 @@
# Initialize the top-level parser
parser = argparse.ArgumentParser(
prog="torchchat",
description="Welcome to the torchchat CLI!",
add_help=True,
)
# Default command is to print help
parser.set_defaults(func=parser.print_help())

add_arguments(parser)
subparsers = parser.add_subparsers(
dest="command",
help="The specific command to run",
)
subparsers.required = True

parser_chat = subparsers.add_parser(
"chat",
Expand Down Expand Up @@ -90,20 +87,6 @@
)
add_arguments_for_remove(parser_remove)

# Move all flags to the front of sys.argv since we don't
# want to use the subparser syntax
flag_args = []
positional_args = []
i = 1
while i < len(sys.argv):
if sys.argv[i].startswith("-"):
flag_args += sys.argv[i : i + 2]
i += 2
else:
positional_args.append(sys.argv[i])
i += 1
sys.argv = sys.argv[:1] + flag_args + positional_args

# Now parse the arguments
args = parser.parse_args()
args = arg_init(args)
Expand Down

0 comments on commit 52c3d91

Please sign in to comment.