From 4d8a05206908b1c4f4d55a8432a201a27958cc2e Mon Sep 17 00:00:00 2001 From: John Karasev Date: Thu, 12 Dec 2024 18:26:14 +0100 Subject: [PATCH] Add --live_input_device_sample_rate parameter --- src/whisper_ctranslate2/commandline.py | 7 +++++++ src/whisper_ctranslate2/live.py | 13 +++++++++---- src/whisper_ctranslate2/whisper_ctranslate2.py | 2 ++ tests/testlive.py | 1 + 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/whisper_ctranslate2/commandline.py b/src/whisper_ctranslate2/commandline.py index 1564c47..10f10e4 100644 --- a/src/whisper_ctranslate2/commandline.py +++ b/src/whisper_ctranslate2/commandline.py @@ -446,4 +446,11 @@ def read_command_line(): help="Set live stream input device ID (see python -m sounddevice for a list)", ) + live_args.add_argument( + "--live_input_device_sample_rate", + type=int, + default=16000, + help="Set live sample rate of input device", + ) + return parser.parse_args().__dict__ diff --git a/src/whisper_ctranslate2/live.py b/src/whisper_ctranslate2/live.py index ab2739e..758fed8 100644 --- a/src/whisper_ctranslate2/live.py +++ b/src/whisper_ctranslate2/live.py @@ -6,7 +6,6 @@ from .transcribe import Transcribe, TranscriptionOptions -SampleRate = 16000 # Stream device recording frequency per second BlockSize = 30 # Block size in milliseconds Vocals = [50, 1000] # Frequency range to detect sounds that could be speech EndBlocks = 33 * 2 # Number of blocks to wait before sending (30 ms is block) @@ -38,6 +37,7 @@ def __init__( verbose: bool, threshold: float, input_device: int, + input_device_sample_rate: int, options: TranscriptionOptions, ): self.model_path = model_path @@ -52,6 +52,7 @@ def __init__( self.verbose = verbose self.threshold = threshold self.input_device = input_device + self.input_device_sample_rate = input_device_sample_rate self.options = options self.running = True @@ -71,7 +72,11 @@ def force_not_available_exception(): raise (sounddevice_exception) def _is_there_voice(self, indata, frames): - freq = np.argmax(np.abs(np.fft.rfft(indata[:, 0]))) * SampleRate / frames + freq = ( + np.argmax(np.abs(np.fft.rfft(indata[:, 0]))) + * self.input_device_sample_rate + / frames + ) volume = np.sqrt(np.mean(indata**2)) return volume > self.threshold and Vocals[0] <= freq <= Vocals[1] @@ -158,8 +163,8 @@ def listen(self): with sd.InputStream( channels=1, callback=self.callback, - blocksize=int(SampleRate * BlockSize / 1000), - samplerate=SampleRate, + blocksize=int(self.input_device_sample_rate * BlockSize / 1000), + samplerate=self.input_device_sample_rate, device=self.input_device, ): while self.running: diff --git a/src/whisper_ctranslate2/whisper_ctranslate2.py b/src/whisper_ctranslate2/whisper_ctranslate2.py index 2a507cb..9cc2473 100644 --- a/src/whisper_ctranslate2/whisper_ctranslate2.py +++ b/src/whisper_ctranslate2/whisper_ctranslate2.py @@ -114,6 +114,7 @@ def main(): local_files_only: bool = args.pop("local_files_only") live_volume_threshold: float = args.pop("live_volume_threshold") live_input_device: int = args.pop("live_input_device") + live_input_device_sample_rate: int = args.pop("live_input_device_sample_rate") hf_token = args.pop("hf_token") speaker_name = args.pop("speaker_name") batched = args.pop("batched") @@ -207,6 +208,7 @@ def main(): verbose, live_volume_threshold, live_input_device, + live_input_device_sample_rate, options, ).inference() diff --git a/tests/testlive.py b/tests/testlive.py index 9192a32..b6233a2 100644 --- a/tests/testlive.py +++ b/tests/testlive.py @@ -19,6 +19,7 @@ def test_constructor(self): threshold=0.2, input_device=0, options=None, + input_device_sample_rate=16000, ) self.assertNotEqual(None, live)