Skip to content

Commit 619dda2

Browse files
committed
fix loading
1 parent a453a36 commit 619dda2

File tree

3 files changed

+4
-14
lines changed

3 files changed

+4
-14
lines changed

.DS_Store

0 Bytes
Binary file not shown.

InstructorEmbedding/instructor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -464,10 +464,10 @@ def _load_sbert_model(self, model_path):
464464

465465
modules = OrderedDict()
466466
for module_config in modules_config:
467-
if module_config['type']=="sentence_transformers.models.Transformer":
467+
if module_config['idx']==0:
468468
print('load INSTRUCTOR_Transformer')
469469
module_class = INSTRUCTOR_Transformer
470-
elif module_config['type']=="sentence_transformers.models.Pooling":
470+
elif module_config['idx']==1:
471471
module_class = INSTRUCTOR_Pooling
472472
else:
473473
module_class = import_from_string(module_config['type'])

train.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -555,18 +555,8 @@ def preprocess_function(examples):
555555
checkpoint = training_args.resume_from_checkpoint
556556
elif last_checkpoint is not None:
557557
checkpoint = last_checkpoint
558-
train_result = trainer.train(resume_from_checkpoint=checkpoint)
559-
trainer.save_model() # Saves the tokenizer too for easy upload
560-
561-
metrics = train_result.metrics
562-
max_train_samples = (
563-
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
564-
)
565-
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
566-
567-
trainer.log_metrics("train", metrics)
568-
trainer.save_metrics("train", metrics)
569-
trainer.save_state()
558+
trainer.train(resume_from_checkpoint=checkpoint)
559+
trainer.model.save(training_args.output_dir)
570560

571561

572562
def _mp_fn(index):

0 commit comments

Comments
 (0)