diff --git a/open_dubbing/speech_to_text.py b/open_dubbing/speech_to_text.py index 959b834..66f6ceb 100644 --- a/open_dubbing/speech_to_text.py +++ b/open_dubbing/speech_to_text.py @@ -13,18 +13,17 @@ # limitations under the License. from typing import Mapping, Sequence -from faster_whisper import WhisperModel import logging from iso639 import Lang +from abc import ABC, abstractmethod -class SpeechToText: +class SpeechToText(ABC): def __init__(self, device="cpu", cpu_threads=0): self.model = None self.device = device self.cpu_threads = cpu_threads - logging.getLogger("faster_whisper").setLevel(logging.ERROR) @property def model(self): @@ -34,41 +33,27 @@ def model(self): def model(self, value): self._model = value + @abstractmethod def load_model(self): - self._model = WhisperModel( - model_size_or_path="medium", - device=self.device, - cpu_threads=self.cpu_threads, - compute_type="float16" if self.device == "cuda" else "int8", - ) + pass + @abstractmethod def get_languages(self): - iso_639_3 = [] - for language in self.model.supported_languages: - if language == "jw": - language = "jv" - - o = Lang(language) - pt3 = o.pt3 - iso_639_3.append(pt3) - return iso_639_3 + pass def _get_iso_639_1(self, iso_639_3: str): o = Lang(iso_639_3) iso_639_1 = o.pt1 return iso_639_1 + @abstractmethod def _transcribe( self, *, vocals_filepath: str, source_language_iso_639_1: str, ) -> str: - segments, _ = self.model.transcribe( - vocals_filepath, - source_language_iso_639_1, - ) - return " ".join(segment.text for segment in segments) + pass def transcribe_audio_chunks( self, diff --git a/open_dubbing/speech_to_text_faster_whisper.py b/open_dubbing/speech_to_text_faster_whisper.py new file mode 100644 index 0000000..cea9502 --- /dev/null +++ b/open_dubbing/speech_to_text_faster_whisper.py @@ -0,0 +1,58 @@ +# Copyright 2024 Jordi Mas i HerÇıandez +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from faster_whisper import WhisperModel +import logging +from iso639 import Lang +from open_dubbing.speech_to_text import SpeechToText + + +class SpeechToTextFasterWhisper(SpeechToText): + + def __init__(self, device="cpu", cpu_threads=0): + self.model = None + self.device = device + self.cpu_threads = cpu_threads + logging.getLogger("faster_whisper").setLevel(logging.ERROR) + + def load_model(self): + self._model = WhisperModel( + model_size_or_path="medium", + device=self.device, + cpu_threads=self.cpu_threads, + compute_type="float16" if self.device == "cuda" else "int8", + ) + + def get_languages(self): + iso_639_3 = [] + for language in self.model.supported_languages: + if language == "jw": + language = "jv" + + o = Lang(language) + pt3 = o.pt3 + iso_639_3.append(pt3) + return iso_639_3 + + def _transcribe( + self, + *, + vocals_filepath: str, + source_language_iso_639_1: str, + ) -> str: + segments, _ = self.model.transcribe( + vocals_filepath, + source_language_iso_639_1, + ) + return " ".join(segment.text for segment in segments) diff --git a/r.sh b/r.sh index 4caaf02..9ee6ed9 100755 --- a/r.sh +++ b/r.sh @@ -4,8 +4,8 @@ rm -r -f output/ declare -a inputs=("videos/cameratips.mp4" "videos/jordi.mp4" "videos/jobinterview.mp4" "videos/jordimaskudosallhands.mp4" "videos/michael.mp4" "videos/simplicty.mp4") declare -a target_languages=("cat") # Catalan (cat) and French (fra) -declare -a tts_list=("mms" "coqui") -declare -a inputs=("videos/jordimaskudosallhands.mp4" "videos/jobinterview.mp4") +declare -a tts_list=("coqui") +declare -a inputs=("videos/jordi.mp4") for tts in "${tts_list[@]}"; do for input_file in "${inputs[@]}"; do @@ -14,7 +14,7 @@ for tts in "${tts_list[@]}"; do output_directory="output/$(basename "${input_file%.*}").${language}.${tts}/" # Run the dubbing command - open-dubbing \ + KMP_DUPLICATE_LIB_OK=TRUE open-dubbing \ --input_file "$input_file" \ --output_directory="$output_directory" \ --source_language=eng \ diff --git a/tests/speech_to_text_test.py b/tests/speech_to_text_test.py index c003b57..78f0c44 100644 --- a/tests/speech_to_text_test.py +++ b/tests/speech_to_text_test.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for utility functions in SpeechToText().py.""" - from collections import namedtuple import tempfile from unittest.mock import MagicMock from faster_whisper import WhisperModel from moviepy.audio.AudioClip import AudioArrayClip import numpy as np -from open_dubbing.speech_to_text import SpeechToText +from open_dubbing.speech_to_text_faster_whisper import SpeechToTextFasterWhisper import pytest import os @@ -45,7 +43,7 @@ def test_transcribe(self): mock_model = MagicMock(spec=WhisperModel) Segment = namedtuple("Segment", ["text"]) mock_model.transcribe.return_value = [Segment(text="Test.")], None - spt = SpeechToText() + spt = SpeechToTextFasterWhisper() spt.model = mock_model transcribed_text = spt._transcribe( vocals_filepath=self.silence_audio, @@ -68,7 +66,7 @@ def test_transcribe_chunks(self, no_dubbing_phrases, expected_for_dubbing): ], None utterance_metadata = [dict(path=self.silence_audio, start=0.0, end=5.0)] source_language = "en" - spt = SpeechToText() + spt = SpeechToTextFasterWhisper() spt.model = mock_model transcribed_audio_chunks = spt.transcribe_audio_chunks( utterance_metadata=utterance_metadata, @@ -123,7 +121,7 @@ def test_add_speaker_info(self): "path": "path/to/file.mp3", }, ] - result = SpeechToText().add_speaker_info(utterance_metadata, speaker_info) + result = SpeechToTextFasterWhisper().add_speaker_info(utterance_metadata, speaker_info) assert result == expected_result def test_add_speaker_info_unequal_lengths(self): @@ -136,4 +134,4 @@ def test_add_speaker_info_unequal_lengths(self): Exception, match="The length of 'utterance_metadata' and 'speaker_info' must be the same.", ): - SpeechToText().add_speaker_info(utterance_metadata, speaker_info) + SpeechToTextFasterWhisper().add_speaker_info(utterance_metadata, speaker_info)