Fix Wav2Vec2 async rebatched (#168)
Co-authored-by: Jinchen Ge <jincheng@graphcore.ai>
gejinchen authored Aug 18, 2022
1 parent 3947cec commit dd73165
Showing 1 changed file with 4 additions and 1 deletion.
examples/speech-pretraining/run_pretraining.py

@@ -505,9 +505,12 @@ def prepare_dataset(batch):
     feature_extractor.save_pretrained(training_args.output_dir)
     config.save_pretrained(training_args.output_dir)

+    # Create a new model under no_grad() just for the collator, to avoid causing a multiprocessing error.
+    with torch.no_grad():
+        model_collator = AutoModelForPreTraining.from_config(config)
     # Instantiate custom data collator
     data_collator = DataCollatorForWav2Vec2Pretraining(
-        model=model,
+        model=model_collator,
         feature_extractor=feature_extractor,
         reducer_keep_factor=model_args.mask_time_prob * (1.0 - model_args.crop_aggression),
     )
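For context, below is a minimal sketch of the pattern this commit introduces. The checkpoint name is an assumption for illustration only; the shape-helper call at the end mirrors how the upstream Hugging Face pretraining script's DataCollatorForWav2Vec2Pretraining uses the model it is given.

import torch
from transformers import AutoConfig, AutoModelForPreTraining

# Assumed base checkpoint, for illustration only.
config = AutoConfig.from_pretrained("facebook/wav2vec2-base")

# Build a throwaway model from the config alone, with gradient tracking off.
# The collator only needs config-derived shape helpers, not trained weights,
# so asynchronous DataLoader workers never have to pickle the training model.
with torch.no_grad():
    model_collator = AutoModelForPreTraining.from_config(config)

# The collator uses the model for shape bookkeeping like this (a private
# transformers method, as called by the upstream pretraining collator):
num_frames = model_collator._get_feat_extract_output_lengths(torch.tensor(16000))
print(int(num_frames))  # encoder frames for a 1 s clip at 16 kHz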
