@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
     return batch


-class INSTRUCTOR_Pooling(nn.Module):
+class INSTRUCTORPooling(nn.Module):
     """Performs pooling (max or mean) on the token embeddings.

     Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
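For context on the class being renamed: the docstring describes standard sentence-transformers pooling. A minimal sketch of the mean-pooling case, assuming the usual convention of masking padded positions with the attention mask (names here are illustrative, not taken from this file):

    import torch

    def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        mask = attention_mask.unsqueeze(-1).float()    # (batch, seq_len, 1)
        summed = (token_embeddings * mask).sum(dim=1)  # sum of non-padding token vectors
        counts = mask.sum(dim=1).clamp(min=1e-9)       # number of non-padding tokens
        return summed / counts                         # fixed-size sentence embedding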
@@ -245,7 +245,7 @@ def load(input_path):
         ) as config_file:
             config = json.load(config_file)

-        return INSTRUCTOR_Pooling(**config)
+        return INSTRUCTORPooling(**config)


 def import_from_string(dotted_path):
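Only the signature of import_from_string appears in this hunk. This helper conventionally mirrors Django's import_string (as in sentence-transformers), resolving a dotted path to a module attribute; a sketch under that assumption:

    import importlib

    def import_from_string(dotted_path):
        # e.g. "sentence_transformers.models.Dense" -> the Dense class
        module_path, _, attr_name = dotted_path.rpartition(".")
        module = importlib.import_module(module_path)
        return getattr(module, attr_name)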
@@ -536,13 +536,6 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
             "cache_dir": cache_folder,
             "tqdm_class": disabled_tqdm,
         }
-        # Try to download from the remote
-        try:
-            model_path = snapshot_download(**download_kwargs)
-        except Exception:
-            # Otherwise, try local (i.e. cache) only
-            download_kwargs["local_files_only"] = True
-            model_path = snapshot_download(**download_kwargs)

         # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
         config_sentence_transformers_json_path = os.path.join(
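For reference, the deleted lines implemented a remote-first download with a cache-only retry via huggingface_hub.snapshot_download. Reassembled as plain code (the construction of download_kwargs is outside this hunk; it must also carry repo_id and similar keys for snapshot_download to work):

    from huggingface_hub import snapshot_download

    download_kwargs = {
        # repo_id, revision, token, etc. are set earlier in the method (not shown here)
        "cache_dir": cache_folder,
        "tqdm_class": disabled_tqdm,
    }
    try:
        # Try to fetch the model from the Hub first
        model_path = snapshot_download(**download_kwargs)
    except Exception:
        # Fall back to whatever is already in the local cache
        download_kwargs["local_files_only"] = True
        model_path = snapshot_download(**download_kwargs)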
@@ -573,7 +566,7 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
             if module_config["idx"] == 0:
                 module_class = INSTRUCTORTransformer
             elif module_config["idx"] == 1:
-                module_class = INSTRUCTOR_Pooling
+                module_class = INSTRUCTORPooling
             else:
                 module_class = import_from_string(module_config["type"])
             module = module_class.load(os.path.join(model_path, module_config["path"]))
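With this change, both places that reference the pooling class (the static load method and this module dispatch) resolve to INSTRUCTORPooling, so round-tripping a saved model keeps working. A minimal usage sketch, assuming the package's public INSTRUCTOR class and the hkunlp/instructor-large checkpoint from its README (neither appears in this diff):

    from InstructorEmbedding import INSTRUCTOR

    model = INSTRUCTOR("hkunlp/instructor-large")
    # Each input pairs an instruction with the text to embed
    embeddings = model.encode([["Represent the science title:",
                                "Quantum entanglement in photonic systems"]])
    print(embeddings.shape)  # e.g. (1, 768) for instructor-large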