diff --git a/README.md b/README.md
index cba6e1f..031ccf5 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,17 @@ All the supported options with their help are shown.
 
 On top of the OpenAI Whisper command line options, there are some specific options provided by CTranslate2 or whiper-ctranslate2.
 
+## Batched inference
+
+Batched inference transcribes each segment independently, which can provide an additional 2x-4x speed increase:
+
+    whisper-ctranslate2 inaguracio2011.mp3 --batched True
+
+You can additionally use `--batch_size` to specify the maximum number of parallel requests to the model for decoding.
+
+Batched inference uses the Voice Activity Detection (VAD) filter and ignores the following parameters: compression_ratio_threshold, logprob_threshold,
+no_speech_threshold, condition_on_previous_text, prompt_reset_on_temperature, prefix, hallucination_silence_threshold.
+
 ## Quantization
 
 `--compute_type` option which accepts _default,auto,int8,int8_float16,int16,float16,float32_ values indicates the type of [quantization](https://opennmt.net/CTranslate2/quantization.html) to use. On CPU _int8_ will give the best performance:
@@ -115,14 +126,14 @@ https://user-images.githubusercontent.com/309265/231533784-e58c4b92-e9fb-4256-b4
 
 ## Diarization (speaker identification)
 
-There is experimental diarization support using [`pyannote.audio`](https://github.com/pyannote/pyannote-audio) to identify speakers. At the moment, the support is a segment level.
+There is experimental diarization support using [`pyannote.audio`](https://github.com/pyannote/pyannote-audio) to identify speakers. At the moment, the support is at segment level.
 
 To enable diarization you need to follow these steps:
 
 1. Install [`pyannote.audio`](https://github.com/pyannote/pyannote-audio) with `pip install pyannote.audio`
 2. Accept [`pyannote/segmentation-3.0`](https://hf.co/pyannote/segmentation-3.0) user conditions
 3. Accept [`pyannote/speaker-diarization-3.1`](https://hf.co/pyannote/speaker-diarization-3.1) user conditions
-4. Create access token at [`hf.co/settings/tokens`](https://hf.co/settings/tokens).
+4. Create an access token at [`hf.co/settings/tokens`](https://hf.co/settings/tokens).
 
 And then execute passing the HuggingFace API token as parameter to enable diarization:
 
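For context on what the new flags drive: the batched path added in transcribe.py (below) wraps faster-whisper's `BatchedInferencePipeline`. A minimal sketch of the equivalent library-level usage, assuming faster-whisper is installed; the model name, device, and batch size of 16 are illustrative values, not defaults taken from this patch:

    from faster_whisper import BatchedInferencePipeline, WhisperModel

    # Load the CTranslate2 Whisper model ("large-v3" is an illustrative choice).
    model = WhisperModel("large-v3", device="cuda", compute_type="float16")

    # Wrap it in the batched pipeline, as the __init__ in transcribe.py below
    # does when batched=True.
    batched_model = BatchedInferencePipeline(model=model)

    # batch_size caps how many segments are decoded in parallel per request.
    segments, info = batched_model.transcribe("inaguracio2011.mp3", batch_size=16)

    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")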
diff --git a/src/whisper_ctranslate2/commandline.py b/src/whisper_ctranslate2/commandline.py
index 6fa6efc..10f10e4 100644
--- a/src/whisper_ctranslate2/commandline.py
+++ b/src/whisper_ctranslate2/commandline.py
@@ -354,7 +354,14 @@ def read_command_line():
         "--batched",
         type=CommandLine._str2bool,
         default=False,
-        help="Uses Batched transcription which can provide an additional 2x-3x speed increase",
+        help="Uses Batched transcription which can provide an additional 2x-4x speed increase",
+    )
+
+    algorithm_args.add_argument(
+        "--batch_size",
+        type=CommandLine._optional_int,
+        default=None,
+        help="When using batched transcription, the maximum number of parallel requests to the model for decoding.",
     )
 
     vad_args = parser.add_argument_group("VAD filter arguments")
diff --git a/src/whisper_ctranslate2/transcribe.py b/src/whisper_ctranslate2/transcribe.py
index 781f3a8..b8a4c24 100644
--- a/src/whisper_ctranslate2/transcribe.py
+++ b/src/whisper_ctranslate2/transcribe.py
@@ -111,6 +111,7 @@ def __init__(
         cache_directory: str,
         local_files_only: bool,
         batched: bool,
+        batch_size: int = None,
     ):
         self.model = WhisperModel(
             model_path,
@@ -121,6 +122,8 @@
             download_root=cache_directory,
             local_files_only=local_files_only,
         )
+
+        self.batch_size = batch_size
         if batched:
             self.batched_model = BatchedInferencePipeline(model=self.model)
         else:
@@ -144,6 +147,10 @@ def inference(
         model = self.model
         vad = options.vad_filter
 
+        batch_size = (
+            {"batch_size": self.batch_size} if self.batch_size is not None else {}
+        )
+
         segments, info = model.transcribe(
             audio=audio,
             language=language,
@@ -171,6 +178,7 @@
             hallucination_silence_threshold=options.hallucination_silence_threshold,
             vad_filter=vad,
             vad_parameters=vad_parameters,
+            **batch_size,
         )
 
         language_name = LANGUAGES[info.language].title()
diff --git a/src/whisper_ctranslate2/version.py b/src/whisper_ctranslate2/version.py
index 574c066..3d18726 100644
--- a/src/whisper_ctranslate2/version.py
+++ b/src/whisper_ctranslate2/version.py
@@ -1 +1 @@
-__version__ = "0.4.9"
+__version__ = "0.5.0"
diff --git a/src/whisper_ctranslate2/whisper_ctranslate2.py b/src/whisper_ctranslate2/whisper_ctranslate2.py
index 0567d50..9cc2473 100644
--- a/src/whisper_ctranslate2/whisper_ctranslate2.py
+++ b/src/whisper_ctranslate2/whisper_ctranslate2.py
@@ -118,6 +118,7 @@ def main():
     hf_token = args.pop("hf_token")
     speaker_name = args.pop("speaker_name")
     batched = args.pop("batched")
+    batch_size = args.pop("batch_size")
 
     language = get_language(language, model_directory, model)
     options = get_transcription_options(args)
@@ -147,6 +148,10 @@
         )
         return
 
+    if batch_size and not batched:
+        sys.stderr.write("--batch_size can only be used if --batched is True\n")
+        return
+
     if args["max_line_count"] and not args["max_line_width"]:
         warnings.warn("--max_line_count has no effect without --max_line_width")
 
@@ -218,6 +223,7 @@
         cache_directory,
         local_files_only,
         batched,
+        batch_size,
     )
 
     diarization = len(hf_token) > 0
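One design note on the `**batch_size` splat in transcribe.py: only `BatchedInferencePipeline.transcribe` accepts a `batch_size` keyword, `WhisperModel.transcribe` does not, so the patch forwards the keyword through a conditionally built dict and simply omits it when unset. A self-contained sketch of that pattern; the functions are toy stand-ins, not the real signatures:

    # Toy stand-ins for the two transcribe entry points (illustrative only).
    def plain_transcribe(audio):
        return f"plain: {audio}"

    def batched_transcribe(audio, batch_size=8):
        return f"batched: {audio} (batch_size={batch_size})"

    def run(audio, batched, batch_size=None):
        fn = batched_transcribe if batched else plain_transcribe
        # Build the kwarg only when it was explicitly set, so the plain
        # entry point, which does not declare batch_size, never receives it.
        extra = {"batch_size": batch_size} if batch_size is not None else {}
        return fn(audio, **extra)

    print(run("audio.mp3", batched=True, batch_size=16))  # batched: audio.mp3 (batch_size=16)
    print(run("audio.mp3", batched=False))                # plain: audio.mp3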