Skip to content

Commit

Permalink
Size
Browse files Browse the repository at this point in the history
  • Loading branch information
jordimas committed Dec 1, 2024
1 parent 5746285 commit 913086a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/whisper_ctranslate2/commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@ def read_command_line():
help="Uses Batched transcription which can provide an additional 2x-4x speed increase",
)

# CLI flag for the batched-inference batch size; only meaningful together
# with --batched (validated in whisper_ctranslate2.main before use).
algorithm_args.add_argument(
    "--batch_size",
    type=CommandLine._optional_int,
    default=None,  # None -> let the pipeline pick its own default
    # Fixed help text: original read "Batched transcriptionthe maximum
    # number of parallel requests to model", missing a space and comma.
    help="When using batched transcription, the maximum number of parallel requests to the model for decoding.",
)

vad_args = parser.add_argument_group("VAD filter arguments")

vad_args.add_argument(
Expand Down
8 changes: 7 additions & 1 deletion src/whisper_ctranslate2/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
cache_directory: str,
local_files_only: bool,
batched: bool,
batch_size: int = None,
):
self.model = WhisperModel(
model_path,
Expand All @@ -122,7 +123,12 @@ def __init__(
local_files_only=local_files_only,
)
# Build the batched inference pipeline when requested; otherwise mark it
# absent so the caller falls back to regular (non-batched) transcription.
#
# BUG FIX: the original branches were inverted — when batch_size WAS
# provided it constructed BatchedInferencePipeline without it, and when
# batch_size was None it passed batch_size=None explicitly. Swapped so the
# user-supplied batch_size actually reaches the pipeline.
if batched:
    if batch_size:
        self.batched_model = BatchedInferencePipeline(
            model=self.model, batch_size=batch_size
        )
    else:
        # No size given: use the pipeline's own default.
        self.batched_model = BatchedInferencePipeline(model=self.model)
else:
    self.batched_model = None

Expand Down
6 changes: 6 additions & 0 deletions src/whisper_ctranslate2/whisper_ctranslate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def main():
# Extract options that are handled directly by main() rather than passed
# through to the transcription-options builder; pop() removes them so the
# remaining `args` dict contains only transcription settings.
hf_token = args.pop("hf_token")
speaker_name = args.pop("speaker_name")
batched = args.pop("batched")
batch_size = args.pop("batch_size")

language = get_language(language, model_directory, model)
options = get_transcription_options(args)
Expand Down Expand Up @@ -146,6 +147,10 @@ def main():
)
return

if batch_size and not batched:
sys.stderr.write("--batched_size can only be used if-- batched is True")
return

# --max_line_count only constrains wrapping, which --max_line_width enables;
# warn (not error) to preserve best-effort behavior for existing scripts.
if args["max_line_count"] and not args["max_line_width"]:
    warnings.warn("--max_line_count has no effect without --max_line_width")

Expand Down Expand Up @@ -216,6 +221,7 @@ def main():
cache_directory,
local_files_only,
batched,
batch_size,
)

diarization = len(hf_token) > 0
Expand Down

0 comments on commit 913086a

Please sign in to comment.