Skip to content

Commit ae975c4

Browse files
authored
Update instructor.py
1 parent 5ede346 commit ae975c4

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

InstructorEmbedding/instructor.py

+3-10
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
2323
return batch
2424

2525

26-
class INSTRUCTOR_Pooling(nn.Module):
26+
class INSTRUCTORPooling(nn.Module):
2727
"""Performs pooling (max or mean) on the token embeddings.
2828
2929
Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
@@ -245,7 +245,7 @@ def load(input_path):
245245
) as config_file:
246246
config = json.load(config_file)
247247

248-
return INSTRUCTOR_Pooling(**config)
248+
return INSTRUCTORPooling(**config)
249249

250250

251251
def import_from_string(dotted_path):
@@ -536,13 +536,6 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
536536
"cache_dir": cache_folder,
537537
"tqdm_class": disabled_tqdm,
538538
}
539-
# Try to download from the remote
540-
try:
541-
model_path = snapshot_download(**download_kwargs)
542-
except Exception:
543-
# Otherwise, try local (i.e. cache) only
544-
download_kwargs["local_files_only"] = True
545-
model_path = snapshot_download(**download_kwargs)
546539

547540
# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
548541
config_sentence_transformers_json_path = os.path.join(
@@ -573,7 +566,7 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
573566
if module_config["idx"] == 0:
574567
module_class = INSTRUCTORTransformer
575568
elif module_config["idx"] == 1:
576-
module_class = INSTRUCTOR_Pooling
569+
module_class = INSTRUCTORPooling
577570
else:
578571
module_class = import_from_string(module_config["type"])
579572
module = module_class.load(os.path.join(model_path, module_config["path"]))

0 commit comments

Comments (0)