diff --git a/whisper/decoding.py b/whisper/decoding.py index ff9261e04..81cd8452b 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -778,7 +778,10 @@ def run(self, mel: Tensor) -> List[DecodingResult]: @torch.no_grad() def decode( - model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions() + model: "Whisper", + mel: Tensor, + options: DecodingOptions = DecodingOptions(), + **kwargs, ) -> Union[DecodingResult, List[DecodingResult]]: """ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). @@ -802,6 +805,9 @@ def decode( if single := mel.ndim == 2: mel = mel.unsqueeze(0) + if kwargs: + options = replace(options, **kwargs) + result = DecodingTask(model, options).run(mel) return result[0] if single else result