
Commit 856211d

Author: ashok.b
Commit message: modified masking before pooling
Parent: e749023

File tree

7 files changed: +414, -330 lines


.gitignore (+4)

@@ -1 +1,5 @@
 .idea/
+/cache
+/evaluation/MTEB/mteb.egg-info
+/**/__pycache__
+/InstructorEmbedding.egg-info
InstructorEmbedding/instructor.py (+376, -275)

Large diffs are not rendered by default.
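Note: the commit message ("modified masking before pooling") refers to the pooling change made in this file, whose diff is not shown. As a rough, hedged illustration only, the sketch below shows masked mean pooling that excludes instruction tokens before averaging; the helper name and the exact semantics of `instruction_mask` (assumed here to be 1 for sentence tokens and 0 for instruction tokens) are assumptions, not the contents of this diff.

```python
import torch

def mean_pool(token_embeddings: torch.Tensor,
              attention_mask: torch.Tensor,
              instruction_mask: torch.Tensor = None) -> torch.Tensor:
    # Hypothetical helper: average token embeddings into one sentence embedding,
    # optionally dropping instruction tokens (instruction_mask: 1 = keep, 0 = drop).
    mask = attention_mask
    if instruction_mask is not None:
        mask = mask * instruction_mask                      # mask instructions before pooling
    mask = mask.unsqueeze(-1).type_as(token_embeddings)     # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(dim=1)           # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)                # avoid division by zero
    return summed / counts
```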

evaluation/MTEB/examples/evaluate_model.py (+2, -2)

@@ -3,7 +3,7 @@
 import logging
 import argparse
 from mteb import MTEB
-from InstructorEmbedding import INSTRUCTOR
+from InstructorEmbedding import Instructor
 if __name__ == '__main__':
     logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser()
@@ -24,7 +24,7 @@
     # from functools import partialmethod
     #
     # tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
-    model = INSTRUCTOR(args.model_name,cache_folder=args.cache_dir)
+    model = Instructor(args.model_name,cache_folder=args.cache_dir)
     evaluation = MTEB(tasks=[args.task_name],task_langs=["en"])
     evaluation.run(model, output_folder=args.output_dir, eval_splits=[args.split],args=args,)
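Note: the only change here is the renamed entry point, `INSTRUCTOR` -> `Instructor`. A minimal usage sketch with the new name follows; the checkpoint name and cache path are placeholders standing in for the script's `--model_name` and `--cache_dir` arguments.

```python
from InstructorEmbedding import Instructor

# Placeholder checkpoint and cache directory (not taken from this commit).
model = Instructor('hkunlp/instructor-base', cache_folder='./cache')

# Inputs are [instruction, sentence] pairs; the trailing 0 flag used by the
# old interface is gone (see AbsTaskRetrieval.py below).
sentences = [
    ['Represent the question for retrieving supporting documents: ',
     'What does masking before pooling change?'],
]
embeddings = model.encode(sentences, batch_size=8)
print(embeddings.shape)
```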

evaluation/MTEB/mteb/abstasks/AbsTaskRetrieval.py (+15, -18)

@@ -597,7 +597,7 @@ def evaluate(
         model,
         split="test",
         batch_size=128,
-        corpus_chunk_size=None,
+        corpus_chunk_size=50000,
         target_devices=None,
         score_function="cos_sim",
         **kwargs
@@ -708,7 +708,7 @@ def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
         instruction = DEFINITIONS[self.args.prompt][self.args.task_name]['query']
         if self.args.prompt:
             for s in queries:
-                new_sentences.append([instruction, s, 0])
+                new_sentences.append([instruction, s])
         else:
             new_sentences = queries

@@ -717,7 +717,6 @@ def encode_queries(self, queries: List[str], batch_size: int, **kwargs):

     def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
         self.count += 1
-        # print('count: ',self.count)
         if type(corpus) is dict:
             sentences = [
                 (corpus["title"][i] + ' ' + corpus["text"][i]).strip()
@@ -733,28 +732,26 @@ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs)
         new_sentences = []
         instruction = DEFINITIONS[self.args.prompt][self.args.task_name]['corpus']
         for s in sentences:
-            new_sentences.append([instruction, s, 0])
-        # kwargs['show_progress_bar'] = False
-        return self.model.encode(sentences, batch_size=128, **kwargs)
+            new_sentences.append([instruction, s])
+        return self.model.encode(new_sentences, batch_size=128, **kwargs)

     def encode_corpus_parallel(
         self, corpus: List[Dict[str, str]], pool: Dict[str, object], batch_size: int, chunk_id: int, **kwargs
     ):
+        sentences = []
         instruction = DEFINITIONS[self.args.prompt][self.args.task_name]['corpus']
         if type(corpus) is dict:
-            sentences = [
-                [instruction, (corpus["title"][i] + self.sep + corpus["text"][i]).strip()]
-                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
-                if "title" in corpus
-                else corpus["text"][i].strip()
-                for i in range(len(corpus["text"]))
-            ]
+            for i in range(len(corpus["text"])):
+                sentence = corpus["text"][i].strip()
+                if "title" in corpus:
+                    sentence = corpus["title"][i].strip() + self.sep + sentence
+                sentences.append([instruction, sentence])
         else:
-            sentences = [
-                [instruction, (doc["title"] + self.sep + doc["text"]).strip()]
-                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
-                for doc in corpus
-            ]
+            for doc in corpus:
+                sentence = doc["text"].strip()
+                if "title" in doc:
+                    sentence = doc["title"].strip() + self.sep + sentence
+                sentences.append([instruction, sentence])

         if chunk_id is not None and chunk_id >= len(pool["processes"]):
             output_queue = pool["output"]
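Note: these edits drop the trailing `0` from the instruction/sentence pairs and fix two bugs in the corpus path: `encode_corpus` now encodes `new_sentences` instead of the raw `sentences`, and `encode_corpus_parallel` replaces a malformed list comprehension with an explicit loop. A standalone sketch of the new pair-building logic, with a plain string standing in for `self.sep` and the `DEFINITIONS` lookup:

```python
from typing import Dict, List

def build_corpus_pairs(corpus: List[Dict[str, str]],
                       instruction: str,
                       sep: str = ' ') -> List[List[str]]:
    # Mirrors the rewritten encode_corpus_parallel loop: prepend the title
    # (when present) and pair each document with its instruction.
    pairs = []
    for doc in corpus:
        sentence = doc["text"].strip()
        if "title" in doc:
            sentence = doc["title"].strip() + sep + sentence
        pairs.append([instruction, sentence])
    return pairs

docs = [{"title": "Pooling", "text": "Mean pooling averages token embeddings."}]
print(build_corpus_pairs(docs, "Represent the document for retrieval: "))
```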

evaluation/MTEB/setup.py (+2)

@@ -84,6 +84,8 @@
         "torch",
         "tqdm",
         "rich",
+        "beir",
+        "evaluate==0.2.0"
     ],
     extras_require=extras,
     classifiers=[

requirements.txt (+1)

@@ -10,3 +10,4 @@ sentence_transformers>=2.2.0
 torch
 tqdm
 rich
+tensorboard

train.py (+14, -35)

@@ -13,7 +13,7 @@

 import transformers
 from filelock import FileLock
-from InstructorEmbedding import INSTRUCTOR
+from InstructorEmbedding import Instructor, InstructorTransformer
 from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
@@ -27,6 +27,9 @@
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
+from transformers.training_args import TrainingArguments
+
 from transformers.utils import check_min_version, is_offline_mode
 from torch.utils.data import Dataset, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -100,7 +103,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             cur_inputs = {
                 'input_ids': inputs[f'{k}_input_ids'],
                 'attention_mask': inputs[f'{k}_attention_mask'],
-                'context_masks': inputs[f'{k}_context_masks'],
+                'instruction_mask': inputs[f'{k}_instruction_mask'],
             }
             cur_results[k] = model(cur_inputs)['sentence_embedding']
         embeddings_query = cur_results['query']
@@ -156,7 +159,6 @@ class ModelArguments:
     """
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """
-
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
@@ -424,13 +426,8 @@ def main():
     )

     # Set seed before initializing model.
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        use_fast=model_args.use_fast_tokenizer,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
+    instructor_tokenizer = InstructorTransformer(model_name_or_path=model_args.model_name_or_path, load_model=False)
+    tokenizer = instructor_tokenizer.tokenizer  # pre-trained tokenizer

     set_seed(training_args.seed)
     with open(os.path.join(model_args.cache_dir, 'medi-data.json')) as f:
@@ -443,7 +440,7 @@

     real_batch_size = max(training_args.per_device_train_batch_size,
                           training_args.per_device_train_batch_size * torch.cuda.device_count())
-    # print('real_batch_size: ', real_batch_size,training_args.per_device_train_batch_size,torch.cuda.device_count())
+
     def get_examples_raw(old_examples_raw, total_n, real_batch_size):
         examples_raw = []
         for idx in range(0, total_n, real_batch_size):
@@ -485,13 +482,11 @@ def get_dataset(examples_raw):
         for i in range(total_num):
             cur_e = examples_raw[i]
             for k in ['query','pos','neg']:
-                for s in cur_e[k][:-1]:
-                    assert not '!@#$%^&**!@#$%^&**' in s
                 cur_e[k][-1] = str(cur_e[k][-1])
                 if not data_args.add_prompt_to_document:
                     cur_e[k][0] = ''
                 assert cur_e[k][0].startswith('Represent ') or cur_e[k][0]==''
-                examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k]))
+                examples[k].append(cur_e[k])
             if not cur_e['task_id'] in task_name_map:
                 task_name_map[cur_e['task_id']] = task_count
                 task_count += 1
@@ -500,36 +495,20 @@ def get_dataset(examples_raw):

     train_raw_datasets = DatasetDict({'train':Dataset.from_dict(get_dataset(train_examples_raw))})

-    model = INSTRUCTOR(real_name_or_path, cache_folder=model_args.cache_dir)
+    model = Instructor(real_name_or_path, cache_folder=model_args.cache_dir)
     column_names = train_raw_datasets["train"].column_names

     def preprocess_function(examples):
         all_tokenized = None
         for key in ['query','pos','neg']:
-            num = len(examples[key])
-            contexts = []
-            concatenated_input_texts = []
-            for local_idx in range(num):
-                splits = examples[key][local_idx].split('!@#$%^&**!@#$%^&**')
-                assert len(splits) == 2
-                contexts.append(splits[0])
-                concatenated_input_texts.append(''.join(splits))
-                assert isinstance(contexts[-1], str)
-                assert isinstance(concatenated_input_texts[-1], str)
-            tokenized = tokenizer(concatenated_input_texts,padding='max_length', truncation='longest_first', return_tensors="pt", max_length=data_args.max_source_length)
-            context_tok = tokenizer(contexts,padding='max_length', truncation='longest_first', return_tensors="pt", max_length=data_args.max_source_length)
-            tokenized['context_masks'] = torch.sum(context_tok['attention_mask'], dim=1)
-            tokenized['context_masks'] = tokenized['context_masks'] - 1
-            for my_idx in range(len(tokenized['context_masks'])):
-                if tokenized['context_masks'][my_idx] <= 1:
-                    tokenized['context_masks'][my_idx] = 0
-            keys = tokenized.keys()
+            input_features = instructor_tokenizer.tokenize(examples[key])
+            keys = input_features.keys()
             if all_tokenized is None:
-                all_tokenized = tokenized.copy()
+                all_tokenized = input_features.copy()
                 for k in keys:
                     all_tokenized[k] = all_tokenized[k].tolist()
             for k in keys:
-                all_tokenized[f'{key}_{k}'] = tokenized[k].tolist()
+                all_tokenized[f'{key}_{k}'] = input_features[k].tolist()
         all_tokenized['task_id'] = examples['task_id']
         return all_tokenized
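Note: the training script now builds its tokenizer from `InstructorTransformer` and keeps each example as an `[instruction, text]` list instead of joining it with the `'!@#$%^&**!@#$%^&**'` separator; `instructor_tokenizer.tokenize` is expected to emit the `input_ids`, `attention_mask`, and `instruction_mask` features that `compute_loss` reads back per key. A small sketch of that per-key regrouping, assuming exactly those three feature names (the toy tensors are illustrative only):

```python
import torch

def per_key_inputs(batch: dict, key: str) -> dict:
    # Mirrors the cur_inputs dict built in compute_loss after this commit:
    # e.g. 'query_input_ids' -> 'input_ids', and likewise for each feature.
    return {
        'input_ids': batch[f'{key}_input_ids'],
        'attention_mask': batch[f'{key}_attention_mask'],
        'instruction_mask': batch[f'{key}_instruction_mask'],
    }

# Toy batch with made-up shapes (batch=2, seq_len=4) for query/pos/neg.
batch = {f'{k}_{name}': torch.zeros(2, 4, dtype=torch.long)
         for k in ['query', 'pos', 'neg']
         for name in ['input_ids', 'attention_mask', 'instruction_mask']}
print(sorted(per_key_inputs(batch, 'query').keys()))
```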
