Initial decoupling
jordimas committed Sep 16, 2024
1 parent 4f7c2c3 commit ed731da
Showing 4 changed files with 74 additions and 33 deletions.
31 changes: 8 additions & 23 deletions open_dubbing/speech_to_text.py
@@ -13,18 +13,17 @@
# limitations under the License.

from typing import Mapping, Sequence
-from faster_whisper import WhisperModel
import logging
from iso639 import Lang
+from abc import ABC, abstractmethod


-class SpeechToText:
+class SpeechToText(ABC):

    def __init__(self, device="cpu", cpu_threads=0):
        self.model = None
        self.device = device
        self.cpu_threads = cpu_threads
-        logging.getLogger("faster_whisper").setLevel(logging.ERROR)

    @property
    def model(self):
@@ -34,41 +33,27 @@ def model(self):
    def model(self, value):
        self._model = value

+    @abstractmethod
    def load_model(self):
-        self._model = WhisperModel(
-            model_size_or_path="medium",
-            device=self.device,
-            cpu_threads=self.cpu_threads,
-            compute_type="float16" if self.device == "cuda" else "int8",
-        )
+        pass

+    @abstractmethod
    def get_languages(self):
-        iso_639_3 = []
-        for language in self.model.supported_languages:
-            if language == "jw":
-                language = "jv"
-
-            o = Lang(language)
-            pt3 = o.pt3
-            iso_639_3.append(pt3)
-        return iso_639_3
+        pass

    def _get_iso_639_1(self, iso_639_3: str):
        o = Lang(iso_639_3)
        iso_639_1 = o.pt1
        return iso_639_1

+    @abstractmethod
    def _transcribe(
        self,
        *,
        vocals_filepath: str,
        source_language_iso_639_1: str,
    ) -> str:
-        segments, _ = self.model.transcribe(
-            vocals_filepath,
-            source_language_iso_639_1,
-        )
-        return " ".join(segment.text for segment in segments)
+        pass

    def transcribe_audio_chunks(
        self,
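
After this change SpeechToText only defines the contract: load_model(), get_languages() and _transcribe() are abstract, while shared helpers such as _get_iso_639_1() and transcribe_audio_chunks() stay in the base class. As a rough illustration of what the decoupling buys (not part of this commit), a new backend only has to fill in the three abstract methods; the SpeechToTextDummy class below and its canned results are invented for this sketch:

# Hypothetical backend, for illustration only -- not part of this commit.
from iso639 import Lang

from open_dubbing.speech_to_text import SpeechToText


class SpeechToTextDummy(SpeechToText):
    """Returns canned results; exists only to show the abstract contract."""

    def load_model(self):
        # A real backend would load an ASR model here (see the
        # faster-whisper implementation in the next file).
        self._model = object()

    def get_languages(self):
        # Callers expect ISO 639-3 codes.
        return [Lang(code).pt3 for code in ("en", "ca")]

    def _transcribe(self, *, vocals_filepath: str, source_language_iso_639_1: str) -> str:
        # A real backend would run inference on vocals_filepath here.
        return ""
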
58 changes: 58 additions & 0 deletions open_dubbing/speech_to_text_faster_whisper.py
@@ -0,0 +1,58 @@
# Copyright 2024 Jordi Mas i Hernàndez <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from faster_whisper import WhisperModel
import logging
from iso639 import Lang
from open_dubbing.speech_to_text import SpeechToText


class SpeechToTextFasterWhisper(SpeechToText):

    def __init__(self, device="cpu", cpu_threads=0):
        self.model = None
        self.device = device
        self.cpu_threads = cpu_threads
        logging.getLogger("faster_whisper").setLevel(logging.ERROR)

    def load_model(self):
        self._model = WhisperModel(
            model_size_or_path="medium",
            device=self.device,
            cpu_threads=self.cpu_threads,
            compute_type="float16" if self.device == "cuda" else "int8",
        )

    def get_languages(self):
        iso_639_3 = []
        for language in self.model.supported_languages:
            if language == "jw":
                language = "jv"

            o = Lang(language)
            pt3 = o.pt3
            iso_639_3.append(pt3)
        return iso_639_3

    def _transcribe(
        self,
        *,
        vocals_filepath: str,
        source_language_iso_639_1: str,
    ) -> str:
        segments, _ = self.model.transcribe(
            vocals_filepath,
            source_language_iso_639_1,
        )
        return " ".join(segment.text for segment in segments)
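
The concrete backend keeps exactly the code that was removed from the base class. A rough usage sketch follows, assuming the "medium" faster-whisper model can be downloaded on first load; the file path and language code are placeholders, and the sketch is not taken from the repository:

from open_dubbing.speech_to_text_faster_whisper import SpeechToTextFasterWhisper

stt = SpeechToTextFasterWhisper(device="cpu", cpu_threads=4)
stt.load_model()  # loads the "medium" faster-whisper model (int8 on CPU)

print(stt.get_languages())  # ISO 639-3 codes supported by the backend

# _transcribe() is the low-level hook behind the shared
# transcribe_audio_chunks() helper; "vocals.wav" and "en" are placeholders.
text = stt._transcribe(
    vocals_filepath="vocals.wav",
    source_language_iso_639_1="en",
)
print(text)
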
6 changes: 3 additions & 3 deletions r.sh
@@ -4,8 +4,8 @@ rm -r -f output/

declare -a inputs=("videos/cameratips.mp4" "videos/jordi.mp4" "videos/jobinterview.mp4" "videos/jordimaskudosallhands.mp4" "videos/michael.mp4" "videos/simplicty.mp4")
declare -a target_languages=("cat") # Catalan (cat) and French (fra)
-declare -a tts_list=("mms" "coqui")
-declare -a inputs=("videos/jordimaskudosallhands.mp4" "videos/jobinterview.mp4")
+declare -a tts_list=("coqui")
+declare -a inputs=("videos/jordi.mp4")

for tts in "${tts_list[@]}"; do
for input_file in "${inputs[@]}"; do
@@ -14,7 +14,7 @@ for tts in "${tts_list[@]}"; do
output_directory="output/$(basename "${input_file%.*}").${language}.${tts}/"

# Run the dubbing command
-open-dubbing \
+KMP_DUPLICATE_LIB_OK=TRUE open-dubbing \
--input_file "$input_file" \
--output_directory="$output_directory" \
--source_language=eng \
12 changes: 5 additions & 7 deletions tests/speech_to_text_test.py
@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""Tests for utility functions in SpeechToText().py."""
-
from collections import namedtuple
import tempfile
from unittest.mock import MagicMock
from faster_whisper import WhisperModel
from moviepy.audio.AudioClip import AudioArrayClip
import numpy as np
-from open_dubbing.speech_to_text import SpeechToText
+from open_dubbing.speech_to_text_faster_whisper import SpeechToTextFasterWhisper
import pytest
import os

@@ -45,7 +43,7 @@ def test_transcribe(self):
        mock_model = MagicMock(spec=WhisperModel)
        Segment = namedtuple("Segment", ["text"])
        mock_model.transcribe.return_value = [Segment(text="Test.")], None
-        spt = SpeechToText()
+        spt = SpeechToTextFasterWhisper()
        spt.model = mock_model
        transcribed_text = spt._transcribe(
            vocals_filepath=self.silence_audio,
@@ -68,7 +66,7 @@ def test_transcribe_chunks(self, no_dubbing_phrases, expected_for_dubbing):
        ], None
        utterance_metadata = [dict(path=self.silence_audio, start=0.0, end=5.0)]
        source_language = "en"
-        spt = SpeechToText()
+        spt = SpeechToTextFasterWhisper()
        spt.model = mock_model
        transcribed_audio_chunks = spt.transcribe_audio_chunks(
            utterance_metadata=utterance_metadata,
@@ -123,7 +121,7 @@ def test_add_speaker_info(self):
                "path": "path/to/file.mp3",
            },
        ]
-        result = SpeechToText().add_speaker_info(utterance_metadata, speaker_info)
+        result = SpeechToTextFasterWhisper().add_speaker_info(utterance_metadata, speaker_info)
        assert result == expected_result

    def test_add_speaker_info_unequal_lengths(self):
@@ -136,4 +134,4 @@ def test_add_speaker_info_unequal_lengths(self):
            Exception,
            match="The length of 'utterance_metadata' and 'speaker_info' must be the same.",
        ):
-            SpeechToText().add_speaker_info(utterance_metadata, speaker_info)
+            SpeechToTextFasterWhisper().add_speaker_info(utterance_metadata, speaker_info)
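
A side effect of the decoupling that the updated tests rely on: the base class is now abstract, so it can no longer be instantiated directly, which is why every test builds SpeechToTextFasterWhisper instead. A small illustrative check (not part of this commit's test file):

import pytest

from open_dubbing.speech_to_text import SpeechToText


def test_base_class_is_abstract():
    # ABCs with unimplemented @abstractmethod members raise TypeError.
    with pytest.raises(TypeError):
        SpeechToText()
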
