diff --git a/README.md b/README.md
index 18d77ec..031ccf5 100644
--- a/README.md
+++ b/README.md
@@ -71,8 +71,11 @@ On top of the OpenAI Whisper command line options, there are some specific optio
 Batched inference transcribes each segment independently, which can provide an additional 2x-4x speed increase:
 
     whisper-ctranslate2 inaguracio2011.mp3 --batched True
+
+You can additionally use --batch_size to specify the maximum number of parallel requests to the model for decoding.
 
-Batched inference uses Voice Activity Detection (VAD) filter.
+Batched inference uses a Voice Activity Detection (VAD) filter and ignores the following parameters: compression_ratio_threshold, logprob_threshold,
+no_speech_threshold, condition_on_previous_text, prompt_reset_on_temperature, prefix, hallucination_silence_threshold.
 
 ## Quantization
 
diff --git a/src/whisper_ctranslate2/commandline.py b/src/whisper_ctranslate2/commandline.py
index 438e380..1564c47 100644
--- a/src/whisper_ctranslate2/commandline.py
+++ b/src/whisper_ctranslate2/commandline.py
@@ -357,6 +357,13 @@ def read_command_line():
         help="Uses Batched transcription which can provide an additional 2x-4x speed increase",
     )
 
+    algorithm_args.add_argument(
+        "--batch_size",
+        type=CommandLine._optional_int,
+        default=None,
+        help="When using Batched transcription, the maximum number of parallel requests to the model for decoding.",
+    )
+
     vad_args = parser.add_argument_group("VAD filter arguments")
 
     vad_args.add_argument(
diff --git a/src/whisper_ctranslate2/transcribe.py b/src/whisper_ctranslate2/transcribe.py
index 781f3a8..b8a4c24 100644
--- a/src/whisper_ctranslate2/transcribe.py
+++ b/src/whisper_ctranslate2/transcribe.py
@@ -111,6 +111,7 @@ def __init__(
         cache_directory: str,
         local_files_only: bool,
         batched: bool,
+        batch_size: int = None,
     ):
         self.model = WhisperModel(
             model_path,
@@ -121,6 +122,8 @@
             download_root=cache_directory,
             local_files_only=local_files_only,
         )
+
+        self.batch_size = batch_size
         if batched:
             self.batched_model = BatchedInferencePipeline(model=self.model)
         else:
@@ -144,6 +147,10 @@ def inference(
             model = self.model
             vad = options.vad_filter
 
+        batch_size = (
+            {"batch_size": self.batch_size} if self.batch_size is not None else {}
+        )
+
         segments, info = model.transcribe(
             audio=audio,
             language=language,
@@ -171,6 +178,7 @@
             hallucination_silence_threshold=options.hallucination_silence_threshold,
             vad_filter=vad,
             vad_parameters=vad_parameters,
+            **batch_size,
         )
 
         language_name = LANGUAGES[info.language].title()
diff --git a/src/whisper_ctranslate2/whisper_ctranslate2.py b/src/whisper_ctranslate2/whisper_ctranslate2.py
index 31c5ca2..2a507cb 100644
--- a/src/whisper_ctranslate2/whisper_ctranslate2.py
+++ b/src/whisper_ctranslate2/whisper_ctranslate2.py
@@ -117,6 +117,7 @@ def main():
     hf_token = args.pop("hf_token")
     speaker_name = args.pop("speaker_name")
     batched = args.pop("batched")
+    batch_size = args.pop("batch_size")
 
     language = get_language(language, model_directory, model)
     options = get_transcription_options(args)
@@ -146,6 +147,10 @@
         )
         return
 
+    if batch_size and not batched:
+        sys.stderr.write("--batch_size can only be used if --batched is True\n")
+        return
+
     if args["max_line_count"] and not args["max_line_width"]:
         warnings.warn("--max_line_count has no effect without --max_line_width")
 
@@ -216,6 +221,7 @@
         cache_directory,
         local_files_only,
         batched,
+        batch_size,
     )
 
     diarization = len(hf_token) > 0
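
For context, here is a minimal sketch of what this change wires up, written directly against faster-whisper rather than through the CLI. It is an illustration under assumptions, not the project's code: the model size "small", the audio file name, and `batch_size=8` are placeholders, and it assumes a faster-whisper release that ships `BatchedInferencePipeline` (the same class the diff instantiates). The equivalent CLI call would be `whisper-ctranslate2 inaguracio2011.mp3 --batched True --batch_size 8`.

```python
# Sketch only: how --batched / --batch_size map onto faster-whisper.
# "small", the file name, and batch_size=8 are illustrative placeholders.
from faster_whisper import WhisperModel, BatchedInferencePipeline

model = WhisperModel("small")
batched_model = BatchedInferencePipeline(model=model)

# batch_size bounds how many VAD-detected segments are decoded in parallel;
# larger values raise throughput at the cost of memory.
segments, info = batched_model.transcribe("inaguracio2011.mp3", batch_size=8)

for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```

The dict-splat in transcribe.py (`**batch_size`) keeps the existing call path untouched: when `--batch_size` is not given, no `batch_size` keyword is forwarded and faster-whisper falls back to its own default.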