@@ -14,6 +14,6 @@
 import transformers
 from filelock import FileLock
-from InstructorEmbedding import INSTRUCTOR
+from InstructorEmbedding import Instructor, InstructorTransformer
 from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
@@ -27,6 +27,9 @@
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
+from transformers.training_args import TrainingArguments
+
 from transformers.utils import check_min_version, is_offline_mode
 from torch.utils.data import Dataset, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -100,7 +103,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             cur_inputs = {
                 'input_ids': inputs[f'{k}_input_ids'],
                 'attention_mask': inputs[f'{k}_attention_mask'],
-                'context_masks': inputs[f'{k}_context_masks'],
+                'instruction_mask': inputs[f'{k}_instruction_mask'],
             }
             cur_results[k] = model(cur_inputs)['sentence_embedding']
         embeddings_query = cur_results['query']
@@ -156,7 +159,6 @@ class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
-
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
@@ -424,13 +426,8 @@ def main():
     )
 
     # Set seed before initializing model.
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        use_fast=model_args.use_fast_tokenizer,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
+    instructor_tokenizer = InstructorTransformer(model_name_or_path=model_args.model_name_or_path, load_model=False)
+    tokenizer = instructor_tokenizer.tokenizer  # pre-trained tokenizer
 
     set_seed(training_args.seed)
     with open(os.path.join(model_args.cache_dir, 'medi-data.json')) as f:
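For reference, a minimal sketch of the new tokenization path set up here, assuming the InstructorEmbedding version this commit targets (which exposes `InstructorTransformer` with the `model_name_or_path`/`load_model` arguments shown in the hunk); the checkpoint name below is illustrative only:

```python
# Minimal sketch of the new tokenizer path; checkpoint name is illustrative.
from InstructorEmbedding import InstructorTransformer

# load_model=False: only the tokenizer/input pipeline is needed on the data
# side, so the transformer weights are not loaded.
instructor_tokenizer = InstructorTransformer(
    model_name_or_path='hkunlp/instructor-base', load_model=False)
tokenizer = instructor_tokenizer.tokenizer  # the underlying HF tokenizer

# tokenize() consumes [instruction, text] pairs (the format get_dataset()
# now emits) and returns model-ready features such as input_ids,
# attention_mask, and the new instruction_mask.
features = instructor_tokenizer.tokenize(
    [['Represent the sentence for retrieval: ', 'A cat sits on the mat.']])
```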
@@ -443,7 +440,7 @@ def main():
 
     real_batch_size = max(training_args.per_device_train_batch_size,
                           training_args.per_device_train_batch_size * torch.cuda.device_count())
-    # print('real_batch_size: ', real_batch_size,training_args.per_device_train_batch_size,torch.cuda.device_count())
+
     def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         examples_raw = []
         for idx in range(0, total_n, real_batch_size):
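A quick worked example of the batch-size arithmetic in this hunk:

```python
# With per_device_train_batch_size=4 and 8 visible GPUs, the global batch
# is 4 * 8 = 32; on a CPU-only machine torch.cuda.device_count() is 0 and
# max() falls back to the per-device value.
per_device = 4
real_batch_size = max(per_device, per_device * 8)  # 32
cpu_fallback = max(per_device, per_device * 0)     # 4
```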
@@ -485,13 +482,11 @@ def get_dataset(examples_raw):
         for i in range(total_num):
             cur_e = examples_raw[i]
             for k in ['query','pos','neg']:
-                for s in cur_e[k][:-1]:
-                    assert not '!@#$%^&**!@#$%^&**' in s
                 cur_e[k][-1] = str(cur_e[k][-1])
                 if not data_args.add_prompt_to_document:
                     cur_e[k][0] = ''
                 assert cur_e[k][0].startswith('Represent ') or cur_e[k][0]==''
-                examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k]))
+                examples[k].append(cur_e[k])
             if not cur_e['task_id'] in task_name_map:
                 task_name_map[cur_e['task_id']] = task_count
                 task_count += 1
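The data-shape change here is easiest to see on a single example (values hypothetical):

```python
# Illustration of the get_dataset() change; values are hypothetical.
pair = ['Represent the Science title: ', 'Dense Representation Learning']

# Before: instruction and text were glued into one string with a sentinel
# delimiter, to be split apart again inside preprocess_function().
old_entry = '!@#$%^&**!@#$%^&**'.join(pair)

# After: the [instruction, text] pair is kept as a list, so the delimiter
# and the collision asserts removed above are no longer needed.
new_entry = pair
```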
@@ -500,36 +495,20 @@ def get_dataset(examples_raw):
 
     train_raw_datasets = DatasetDict({'train': Dataset.from_dict(get_dataset(train_examples_raw))})
 
-    model = INSTRUCTOR(real_name_or_path, cache_folder=model_args.cache_dir)
+    model = Instructor(real_name_or_path, cache_folder=model_args.cache_dir)
     column_names = train_raw_datasets["train"].column_names
 
     def preprocess_function(examples):
         all_tokenized = None
         for key in ['query','pos','neg']:
-            num = len(examples[key])
-            contexts = []
-            concatenated_input_texts = []
-            for local_idx in range(num):
-                splits = examples[key][local_idx].split('!@#$%^&**!@#$%^&**')
-                assert len(splits) == 2
-                contexts.append(splits[0])
-                concatenated_input_texts.append(''.join(splits))
-                assert isinstance(contexts[-1], str)
-                assert isinstance(concatenated_input_texts[-1], str)
-            tokenized = tokenizer(concatenated_input_texts, padding='max_length', truncation='longest_first', return_tensors="pt", max_length=data_args.max_source_length)
-            context_tok = tokenizer(contexts, padding='max_length', truncation='longest_first', return_tensors="pt", max_length=data_args.max_source_length)
-            tokenized['context_masks'] = torch.sum(context_tok['attention_mask'], dim=1)
-            tokenized['context_masks'] = tokenized['context_masks'] - 1
-            for my_idx in range(len(tokenized['context_masks'])):
-                if tokenized['context_masks'][my_idx] <= 1:
-                    tokenized['context_masks'][my_idx] = 0
-            keys = tokenized.keys()
+            input_features = instructor_tokenizer.tokenize(examples[key])
+            keys = input_features.keys()
             if all_tokenized is None:
-                all_tokenized = tokenized.copy()
+                all_tokenized = input_features.copy()
                 for k in keys:
                     all_tokenized[k] = all_tokenized[k].tolist()
             for k in keys:
-                all_tokenized[f'{key}_{k}'] = tokenized[k].tolist()
+                all_tokenized[f'{key}_{k}'] = input_features[k].tolist()
         all_tokenized['task_id'] = examples['task_id']
         return all_tokenized
 
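Downstream, the prefixed features produced here line up with what `compute_loss()` (first hunk) reads back out; a rough sketch with dummy tensors:

```python
# How the prefixed features are consumed in compute_loss(); mirrors the
# first hunk, with dummy tensors of shape [batch, seq] standing in for
# real tokenized inputs.
import torch

batch = {f'{k}_{name}': torch.zeros(2, 8, dtype=torch.long)
         for k in ['query', 'pos', 'neg']
         for name in ['input_ids', 'attention_mask', 'instruction_mask']}

for k in ['query', 'pos', 'neg']:
    cur_inputs = {
        'input_ids': batch[f'{k}_input_ids'],
        'attention_mask': batch[f'{k}_attention_mask'],
        # instruction_mask replaces the old context_masks and marks the
        # instruction tokens within each tokenized [instruction, text] pair.
        'instruction_mask': batch[f'{k}_instruction_mask'],
    }
    # each group is embedded separately:
    # cur_results[k] = model(cur_inputs)['sentence_embedding']
```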