Skip to content

Commit

Permalink
Add --live_input_device_sample_rate parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
karasjoh000 authored and jordimas committed Dec 12, 2024
1 parent 4e7c914 commit 4d8a052
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/whisper_ctranslate2/commandline.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,4 +446,11 @@ def read_command_line():
help="Set live stream input device ID (see python -m sounddevice for a list)",
)

live_args.add_argument(
"--live_input_device_sample_rate",
type=int,
default=16000,
help="Set live sample rate of input device",
)

return parser.parse_args().__dict__
13 changes: 9 additions & 4 deletions src/whisper_ctranslate2/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from .transcribe import Transcribe, TranscriptionOptions

SampleRate = 16000 # Stream device recording frequency per second
BlockSize = 30 # Block size in milliseconds
Vocals = [50, 1000] # Frequency range to detect sounds that could be speech
EndBlocks = 33 * 2 # Number of blocks to wait before sending (30 ms is block)
Expand Down Expand Up @@ -38,6 +37,7 @@ def __init__(
verbose: bool,
threshold: float,
input_device: int,
input_device_sample_rate: int,
options: TranscriptionOptions,
):
self.model_path = model_path
Expand All @@ -52,6 +52,7 @@ def __init__(
self.verbose = verbose
self.threshold = threshold
self.input_device = input_device
self.input_device_sample_rate = input_device_sample_rate
self.options = options

self.running = True
Expand All @@ -71,7 +72,11 @@ def force_not_available_exception():
raise (sounddevice_exception)

def _is_there_voice(self, indata, frames):
    """Decide whether a recorded audio block looks like speech.

    A block is treated as voice when its RMS volume is above
    ``self.threshold`` and the strongest frequency component of its
    first channel falls inside the ``Vocals`` range.

    Args:
        indata: 2-D sample array; column 0 is used for the frequency
            estimate and the whole array for the volume.
        frames: Number of frames in the block (assumed to equal the
            number of rows in ``indata`` -- confirm against the caller).

    Returns:
        A truthy value when both the volume and frequency checks pass.
    """
    # Strongest real-FFT bin of the first channel. Scaling the bin index
    # by sample_rate / frames converts it to a frequency, so the rate used
    # here must be the configured device rate rather than a fixed constant.
    peak_bin = np.argmax(np.abs(np.fft.rfft(indata[:, 0])))
    freq = peak_bin * self.input_device_sample_rate / frames

    # Root-mean-square amplitude over the whole block.
    volume = np.sqrt(np.mean(indata**2))

    return volume > self.threshold and Vocals[0] <= freq <= Vocals[1]
Expand Down Expand Up @@ -158,8 +163,8 @@ def listen(self):
with sd.InputStream(
channels=1,
callback=self.callback,
blocksize=int(SampleRate * BlockSize / 1000),
samplerate=SampleRate,
blocksize=int(self.input_device_sample_rate * BlockSize / 1000),
samplerate=self.input_device_sample_rate,
device=self.input_device,
):
while self.running:
Expand Down
2 changes: 2 additions & 0 deletions src/whisper_ctranslate2/whisper_ctranslate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def main():
local_files_only: bool = args.pop("local_files_only")
live_volume_threshold: float = args.pop("live_volume_threshold")
live_input_device: int = args.pop("live_input_device")
live_input_device_sample_rate: int = args.pop("live_input_device_sample_rate")
hf_token = args.pop("hf_token")
speaker_name = args.pop("speaker_name")
batched = args.pop("batched")
Expand Down Expand Up @@ -207,6 +208,7 @@ def main():
verbose,
live_volume_threshold,
live_input_device,
live_input_device_sample_rate,
options,
).inference()

Expand Down
1 change: 1 addition & 0 deletions tests/testlive.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_constructor(self):
threshold=0.2,
input_device=0,
options=None,
input_device_sample_rate=16000,
)
self.assertNotEqual(None, live)

Expand Down

0 comments on commit 4d8a052

Please sign in to comment.