File tree: 3 files changed, +4 −14 lines changed
3 files changed, +4 −14 lines changed
Original file line number Diff line number Diff line change
@@ -464,10 +464,10 @@ def _load_sbert_model(self, model_path):
464
464
465
465
modules = OrderedDict()
466
466
for module_config in modules_config:
467
- if module_config['type'] == "sentence_transformers.models.Transformer":
467
+ if module_config['idx'] == 0:
468
468
print('load INSTRUCTOR_Transformer')
469
469
module_class = INSTRUCTOR_Transformer
470
- elif module_config['type'] == "sentence_transformers.models.Pooling":
470
+ elif module_config['idx'] == 1:
471
471
module_class = INSTRUCTOR_Pooling
472
472
else:
473
473
module_class = import_from_string(module_config['type'])
Original file line number Diff line number Diff line change
@@ -555,18 +555,8 @@ def preprocess_function(examples):
555
555
checkpoint = training_args.resume_from_checkpoint
556
556
elif last_checkpoint is not None:
557
557
checkpoint = last_checkpoint
558
- train_result = trainer.train(resume_from_checkpoint=checkpoint)
559
- trainer.save_model()  # Saves the tokenizer too for easy upload
560
-
561
- metrics = train_result.metrics
562
- max_train_samples = (
563
- data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
564
- )
565
- metrics["train_samples"] = min(max_train_samples, len(train_dataset))
566
-
567
- trainer.log_metrics("train", metrics)
568
- trainer.save_metrics("train", metrics)
569
- trainer.save_state()
558
+ trainer.train(resume_from_checkpoint=checkpoint)
559
+ trainer.model.save(training_args.output_dir)
570
560
571
561
572
562
def _mp_fn(index):
You can’t perform that action at this time.
0 commit comments