
Commit

When using Batched transcription the maximum number of parallel requests to model for decoding
jordimas committed Dec 1, 2024
1 parent d428596 commit 7b2bd9c
Showing 4 changed files with 25 additions and 1 deletion.
5 changes: 4 additions & 1 deletion README.md
@@ -71,8 +71,11 @@ On top of the OpenAI Whisper command line options, there are some specific options
Batched inference transcribes each segment independently, which can provide an additional 2x-4x speed increase:

whisper-ctranslate2 inaguracio2011.mp3 --batched True

You can additionally use the --batch_size option to specify the maximum number of parallel requests to the model for decoding.
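
For example, an illustrative invocation combining both options (the batch size value of 8 is only an example):

    whisper-ctranslate2 inaguracio2011.mp3 --batched True --batch_size 8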

Batched inference uses Voice Activity Detection (VAD) filter.
Batched inference uses a Voice Activity Detection (VAD) filter and ignores the following parameters: compression_ratio_threshold, logprob_threshold,
no_speech_threshold, condition_on_previous_text, prompt_reset_on_temperature, prefix, hallucination_silence_threshold.

## Quantization

7 changes: 7 additions & 0 deletions src/whisper_ctranslate2/commandline.py
@@ -357,6 +357,13 @@ def read_command_line():
help="Uses Batched transcription which can provide an additional 2x-4x speed increase",
)

algorithm_args.add_argument(
"--batch_size",
type=CommandLine._optional_int,
default=None,
help="When using Batched transcription the maximum number of parallel requests to model for decoding.",
)

vad_args = parser.add_argument_group("VAD filter arguments")

vad_args.add_argument(
8 changes: 8 additions & 0 deletions src/whisper_ctranslate2/transcribe.py
@@ -111,6 +111,7 @@ def __init__(
cache_directory: str,
local_files_only: bool,
batched: bool,
batch_size: int = None,
):
self.model = WhisperModel(
model_path,
@@ -121,6 +122,8 @@
download_root=cache_directory,
local_files_only=local_files_only,
)

self.batch_size = batch_size
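# Wrap the model in faster-whisper's BatchedInferencePipeline when batched transcription is requested.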
if batched:
self.batched_model = BatchedInferencePipeline(model=self.model)
else:
@@ -144,6 +147,10 @@ def inference(
model = self.model
vad = options.vad_filter

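# Pass batch_size to transcribe() only when it was set explicitly, so the library default applies otherwise.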
batch_size = (
{"batch_size": self.batch_size} if self.batch_size is not None else {}
)

segments, info = model.transcribe(
audio=audio,
language=language,
@@ -171,6 +178,7 @@
hallucination_silence_threshold=options.hallucination_silence_threshold,
vad_filter=vad,
vad_parameters=vad_parameters,
**batch_size,
)

language_name = LANGUAGES[info.language].title()
6 changes: 6 additions & 0 deletions src/whisper_ctranslate2/whisper_ctranslate2.py
@@ -117,6 +117,7 @@ def main():
hf_token = args.pop("hf_token")
speaker_name = args.pop("speaker_name")
batched = args.pop("batched")
batch_size = args.pop("batch_size")

language = get_language(language, model_directory, model)
options = get_transcription_options(args)
@@ -146,6 +147,10 @@ def main():
)
return

if batch_size and not batched:
sys.stderr.write("--batch_size can only be used if --batched is True\n")
return

if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")

@@ -216,6 +221,7 @@
cache_directory,
local_files_only,
batched,
batch_size,
)

diarization = len(hf_token) > 0
