Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CLI arg handling #488

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ def _add_arguments_common(parser):
help="Model name for well-known models",
)


def add_arguments(parser):
# TODO: Refactor this so that only common options are here
# and command-specific options are inside individual
# add_arguments_for_generate, add_arguments_for_export etc.

parser.add_argument(
"--chat",
action="store_true",
Expand Down Expand Up @@ -301,10 +295,10 @@ def add_arguments(parser):


def arg_init(args):
if Path(args.quantize).is_file():
if hasattr(args, 'quantize') and Path(args.quantize).is_file():
with open(args.quantize, "r") as f:
args.quantize = json.loads(f.read())

if args.seed:
if hasattr(args, 'seed') and args.seed:
torch.manual_seed(args.seed)
return args
3 changes: 1 addition & 2 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from build.model import Transformer
from build.utils import set_precision
from cli import add_arguments, add_arguments_for_eval, arg_init
from cli import add_arguments_for_eval, arg_init
from generate import encode_tokens, model_forward

torch._dynamo.config.automatic_dynamic_shapes = True
Expand Down Expand Up @@ -289,7 +289,6 @@ def main(args) -> None:

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="torchchat eval CLI")
add_arguments(parser)
add_arguments_for_eval(parser)
args = parser.parse_args()
args = arg_init(args)
Expand Down
3 changes: 1 addition & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)

from build.utils import set_backend, set_precision
from cli import add_arguments, add_arguments_for_export, arg_init, check_args
from cli import add_arguments_for_export, arg_init, check_args
from export_aoti import export_model as export_model_aoti

try:
Expand Down Expand Up @@ -104,7 +104,6 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="torchchat export CLI")
add_arguments(parser)
add_arguments_for_export(parser)
args = parser.parse_args()
check_args(args, "export")
Expand Down
3 changes: 1 addition & 2 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from build.model import Transformer
from build.utils import device_sync, set_precision
from cli import add_arguments, add_arguments_for_generate, arg_init, check_args
from cli import add_arguments_for_generate, arg_init, check_args

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -710,7 +710,6 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="torchchat generate CLI")
add_arguments(parser)
add_arguments_for_generate(parser)
args = parser.parse_args()
check_args(args, "generate")
Expand Down
20 changes: 1 addition & 19 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import sys

from cli import (
add_arguments,
add_arguments_for_browser,
add_arguments_for_chat,
add_arguments_for_download,
Expand All @@ -30,17 +29,14 @@
# Initialize the top-level parser
parser = argparse.ArgumentParser(
prog="torchchat",
description="Welcome to the torchchat CLI!",
add_help=True,
)
# Default command is to print help
parser.set_defaults(func=parser.print_help())

add_arguments(parser)
subparsers = parser.add_subparsers(
dest="command",
help="The specific command to run",
)
subparsers.required = True

parser_chat = subparsers.add_parser(
"chat",
Expand Down Expand Up @@ -90,20 +86,6 @@
)
add_arguments_for_remove(parser_remove)

# Move all flags to the front of sys.argv since we don't
# want to use the subparser syntax
flag_args = []
positional_args = []
i = 1
while i < len(sys.argv):
if sys.argv[i].startswith("-"):
flag_args += sys.argv[i : i + 2]
i += 2
else:
positional_args.append(sys.argv[i])
i += 1
sys.argv = sys.argv[:1] + flag_args + positional_args

# Now parse the arguments
args = parser.parse_args()
args = arg_init(args)
Expand Down
Loading