Skip to content

Commit

Permalink
Fix the number of classes (num_classes=3) in the unsupervised text matching example.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Feb 27, 2022
1 parent f61e330 commit 5f7ee24
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/training_unsup_text_matching_model_en.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def load_en_stsb_dataset(stsb_file):
return train_samples, valid_samples, test_samples


def load_en_nli_dataset(nli_file, limit_size=200000):
def load_en_nli_dataset(nli_file, limit_size=100000):
# Load NLI train dataset
if not os.path.exists(nli_file):
http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_file)
Expand Down Expand Up @@ -104,7 +104,7 @@ def main():
parser.add_argument('--num_epochs', default=10, type=int, help='Number of training epochs')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
parser.add_argument('--learning_rate', default=2e-5, type=float, help='Learning rate')
parser.add_argument('--nli_limit_size', default=200000, type=float, help='Learning rate')
parser.add_argument('--nli_limit_size', default=100000, type=float, help='Learning rate')
parser.add_argument('--encoder_type', default='FIRST_LAST_AVG', type=lambda t: EncoderType[t],
choices=list(EncoderType), help='Encoder type, string name of EncoderType')
args = parser.parse_args()
Expand All @@ -124,11 +124,11 @@ def main():
train_dataset = CosentTrainDataset(model.tokenizer, train_samples, args.max_seq_length)
elif args.model_arch == 'sentencebert':
model = SentenceBertModel(model_name_or_path=args.model_name, encoder_type=args.encoder_type,
max_seq_length=args.max_seq_length)
max_seq_length=args.max_seq_length, num_classes=3)
train_dataset = TextMatchingTrainDataset(model.tokenizer, train_samples, args.max_seq_length)
else:
model = BertMatchModel(model_name_or_path=args.model_name, encoder_type=args.encoder_type,
max_seq_length=args.max_seq_length)
max_seq_length=args.max_seq_length, num_classes=3)
train_dataset = TextMatchingTrainDataset(model.tokenizer, train_samples, args.max_seq_length)
valid_dataset = TextMatchingTestDataset(model.tokenizer, valid_samples, args.max_seq_length)

Expand Down

0 comments on commit 5f7ee24

Please sign in to comment.