Skip to content

Commit

Permalink
update: add support for dynamic module loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Oct 22, 2023
1 parent efd913d commit 45e6505
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 15 deletions.
5 changes: 4 additions & 1 deletion bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
self.speakers = [item[0] for item in
sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
self.symbols = symbols


# Compatible with legacy versions
self.version = process_legacy_versions(self.hps_ms)

if self.version in ["1.0", "1.0.0", "1.0.1"]:
self.symbols = symbols_legacy
self.hps_ms.model.n_layers_trans_flow = 3


elif self.version in ["1.1.0-transition"]:
self.hps_ms.model.n_layers_trans_flow = 3
Expand All @@ -35,6 +37,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):

key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value
self.lang = lang_dict.get(key, ["unknown"])
self.bert_handler = BertHandler(self.lang)

self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}

Expand Down Expand Up @@ -70,7 +73,7 @@ def get_text(self, text, language_str, hps):
for i in range(len(word2ph)):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1
bert = get_bert(norm_text, word2ph, language_str)
bert = self.bert_handler.get_bert(norm_text, word2ph, language_str)
del word2ph
assert bert.shape[-1] == len(phone), phone

Expand Down
10 changes: 1 addition & 9 deletions bert_vits2/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from bert_vits2.text.symbols import *
from .chinese_bert import get_bert_feature as zh_bert
from .english_bert_mock import get_bert_feature as en_bert
from .japanese_bert import get_bert_feature as ja_bert
from bert_vits2.text.bert_handler import BertHandler


def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
Expand All @@ -17,9 +15,3 @@ def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
lang_id = language_id_map[language]
lang_ids = [lang_id for i in phones]
return phones, tones, lang_ids


def get_bert(norm_text, word2ph, language):
lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert}
bert = lang_bert_func_map[language](norm_text, word2ph)
return bert
33 changes: 33 additions & 0 deletions bert_vits2/text/bert_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import importlib


class BertHandler:
_bert_functions = {}

BERT_IMPORT_MAP = {
"zh": "bert_vits2.text.chinese_bert.get_bert_feature",
"en": "bert_vits2.text.english_bert_mock.get_bert_feature",
"ja": "bert_vits2.text.japanese_bert.get_bert_feature",
}

def __init__(self, languages):
for lang in languages:
if lang not in BertHandler._bert_functions:
self.load_bert_function(lang)

def load_bert_function(self, language):
if language not in BertHandler.BERT_IMPORT_MAP:
raise ValueError(f"Unsupported language: {language}")

module_path, function_name = BertHandler.BERT_IMPORT_MAP[language].rsplit('.', 1)
module = importlib.import_module(module_path, package=__package__)
bert_function = getattr(module, function_name)

BertHandler._bert_functions[language] = bert_function

def get_bert(self, norm_text, word2ph, language):
if language not in BertHandler._bert_functions:
raise ValueError(f"BERT for {language} has not been initialized. Please initialize first.")

bert_func = BertHandler._bert_functions[language]
return bert_func(norm_text, word2ph)
24 changes: 19 additions & 5 deletions bert_vits2/text/cleaner.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
from bert_vits2.text import chinese, japanese, cleaned_text_to_sequence
import importlib
from bert_vits2.text import cleaned_text_to_sequence

language_module_map = {
'zh': chinese,
'ja': japanese
'zh': "bert_vits2.text.chinese",
'ja': "bert_vits2.text.japanese"
}

_loaded_modules = {}


def get_language_module(language):
if language not in _loaded_modules:
module_path = language_module_map.get(language)
if not module_path:
raise ValueError(f"Unsupported language: {language}")

_loaded_modules[language] = importlib.import_module(module_path)

return _loaded_modules[language]


def clean_text(text, language):
language_module = language_module_map[language]
language_module = get_language_module(language)
norm_text = language_module.text_normalize(text)
phones, tones, word2ph = language_module.g2p(norm_text)
return norm_text, phones, tones, word2ph


def clean_text_bert(text, language):
language_module = language_module_map[language]
language_module = get_language_module(language)
norm_text = language_module.text_normalize(text)
phones, tones, word2ph = language_module.g2p(norm_text)
bert = language_module.get_bert_feature(norm_text, word2ph)
Expand Down

0 comments on commit 45e6505

Please sign in to comment.