Skip to content

Commit d3b2b55

Browse files
frankfliuzachgk
andauthored
[huggingface] Allows retry failed model in model converter (#1989)
* [huggingface] Allows retry failed model in model converter Co-authored-by: Zach Kimberg <zachary@kimberg.com>
1 parent bce07da commit d3b2b55

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
__pycache__
22
model/
3-
processed_models.json
3+
tmp/
4+
models.json

extensions/tokenizers/src/main/python/arg_parser.py

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def converter_args():
2323
default=1,
2424
help="Max amount of models to convert")
2525
parser.add_argument("-o", "--output-dir", help="Model output directory")
26+
parser.add_argument("-r",
27+
"--retry-failed",
28+
action='store_true',
29+
help="Retry failed model")
2630
group = parser.add_mutually_exclusive_group(required=True)
2731
group.add_argument(
2832
"-c",

extensions/tokenizers/src/main/python/huggingface_models.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"ForTokenClassification": "token-classification",
2626
"ForSequenceClassification": "text-classification",
2727
"ForMultipleChoice": "text-classification",
28-
"ForMaskedLM": "fill-mask"
28+
"ForMaskedLM": "fill-mask",
2929
}
3030
LANGUAGES = ModelSearchArguments().language
3131

@@ -48,7 +48,7 @@ def __init__(self, output_dir: str):
4848
self.output_dir = output_dir
4949
self.processed_models = {}
5050

51-
output_path = os.path.join(output_dir, "processed_models.json")
51+
output_path = os.path.join(output_dir, "models.json")
5252
if os.path.exists(output_path):
5353
with open(output_path, "r") as f:
5454
self.processed_models = json.load(f)
@@ -90,8 +90,9 @@ def list_models(self, args: Namespace) -> List[dict]:
9090
existing_model = self.processed_models.get(model_id)
9191
if existing_model:
9292
existing_model["downloads"] = model_info.downloads
93-
logging.info(f"Skip converted mode: {model_id}.")
94-
continue
93+
if not args.retry_failed:
94+
logging.info(f"Skip converted model: {model_id}.")
95+
continue
9596

9697
try:
9798
config = hf_hub_download(repo_id=model_id,

0 commit comments

Comments
 (0)