From 45e650525cd604f782f463242ab75fedbfd81efb Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Sun, 22 Oct 2023 16:58:09 +0800 Subject: [PATCH] update: add support for dynamic module loading --- bert_vits2/bert_vits2.py | 5 ++++- bert_vits2/text/__init__.py | 10 +--------- bert_vits2/text/bert_handler.py | 33 +++++++++++++++++++++++++++++++++ bert_vits2/text/cleaner.py | 24 +++++++++++++++++++----- 4 files changed, 57 insertions(+), 15 deletions(-) create mode 100644 bert_vits2/text/bert_handler.py diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py index d1d58b3..88c7ee9 100644 --- a/bert_vits2/bert_vits2.py +++ b/bert_vits2/bert_vits2.py @@ -19,6 +19,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): self.speakers = [item[0] for item in sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])] self.symbols = symbols + # Compatible with legacy versions self.version = process_legacy_versions(self.hps_ms) @@ -26,6 +27,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): if self.version in ["1.0", "1.0.0", "1.0.1"]: self.symbols = symbols_legacy self.hps_ms.model.n_layers_trans_flow = 3 + elif self.version in ["1.1.0-transition"]: self.hps_ms.model.n_layers_trans_flow = 3 @@ -35,6 +37,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value self.lang = lang_dict.get(key, ["unknown"]) + self.bert_handler = BertHandler(self.lang) self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)} @@ -70,7 +73,7 @@ def get_text(self, text, language_str, hps): for i in range(len(word2ph)): word2ph[i] = word2ph[i] * 2 word2ph[0] += 1 - bert = get_bert(norm_text, word2ph, language_str) + bert = self.bert_handler.get_bert(norm_text, word2ph, language_str) del word2ph assert bert.shape[-1] == len(phone), phone diff --git 
a/bert_vits2/text/__init__.py b/bert_vits2/text/__init__.py
index 91aa85f..550135f 100644
--- a/bert_vits2/text/__init__.py
+++ b/bert_vits2/text/__init__.py
@@ -1,7 +1,5 @@
 from bert_vits2.text.symbols import *
-from .chinese_bert import get_bert_feature as zh_bert
-from .english_bert_mock import get_bert_feature as en_bert
-from .japanese_bert import get_bert_feature as ja_bert
+from bert_vits2.text.bert_handler import BertHandler
 
 
 def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
@@ -17,9 +15,3 @@ def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
     lang_id = language_id_map[language]
     lang_ids = [lang_id for i in phones]
     return phones, tones, lang_ids
-
-
-def get_bert(norm_text, word2ph, language):
-    lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert}
-    bert = lang_bert_func_map[language](norm_text, word2ph)
-    return bert
diff --git a/bert_vits2/text/bert_handler.py b/bert_vits2/text/bert_handler.py
new file mode 100644
index 0000000..fb5c790
--- /dev/null
+++ b/bert_vits2/text/bert_handler.py
@@ -0,0 +1,33 @@
+import importlib
+
+
+class BertHandler:
+    """Import each language's BERT feature function on demand and cache it class-wide."""
+    _bert_functions = {}
+
+    BERT_IMPORT_MAP = {
+        "zh": "bert_vits2.text.chinese_bert.get_bert_feature",
+        "en": "bert_vits2.text.english_bert_mock.get_bert_feature",
+        "ja": "bert_vits2.text.japanese_bert.get_bert_feature",
+    }
+
+    def __init__(self, languages):
+        # Preload mapped languages only; an unmapped entry (e.g. "unknown") must not abort construction.
+        for lang in languages:
+            if lang in BertHandler.BERT_IMPORT_MAP and lang not in BertHandler._bert_functions:
+                self.load_bert_function(lang)
+
+    def load_bert_function(self, language):
+        """Import and cache the feature function for `language`; raise ValueError if it is unmapped."""
+        if language not in BertHandler.BERT_IMPORT_MAP:
+            raise ValueError(f"Unsupported language: {language}")
+
+        module_path, function_name = BertHandler.BERT_IMPORT_MAP[language].rsplit('.', 1)
+        module = importlib.import_module(module_path)  # absolute path, so no `package` anchor
+        BertHandler._bert_functions[language] = getattr(module, function_name)
+
+    def get_bert(self, norm_text, word2ph, language):
+        """Call the feature function for `language`, importing it on first use (ValueError if unmapped)."""
+        if language not in BertHandler._bert_functions:
+            self.load_bert_function(language)
+        return BertHandler._bert_functions[language](norm_text, word2ph)
diff --git a/bert_vits2/text/cleaner.py b/bert_vits2/text/cleaner.py
index e426eac..d8bc51d 100644
--- a/bert_vits2/text/cleaner.py
+++ b/bert_vits2/text/cleaner.py
@@ -1,20 +1,34 @@
-from bert_vits2.text import chinese, japanese, cleaned_text_to_sequence
+import importlib
+from bert_vits2.text import cleaned_text_to_sequence
 
 language_module_map = {
-    'zh': chinese,
-    'ja': japanese
+    'zh': "bert_vits2.text.chinese",
+    'ja': "bert_vits2.text.japanese"
 }
 
+_loaded_modules = {}
+
+
+def get_language_module(language):
+    if language not in _loaded_modules:
+        module_path = language_module_map.get(language)
+        if not module_path:
+            raise ValueError(f"Unsupported language: {language}")
+
+        _loaded_modules[language] = importlib.import_module(module_path)
+
+    return _loaded_modules[language]
+
 
 def clean_text(text, language):
-    language_module = language_module_map[language]
+    language_module = get_language_module(language)
     norm_text = language_module.text_normalize(text)
     phones, tones, word2ph = language_module.g2p(norm_text)
     return norm_text, phones, tones, word2ph
 
 
 def clean_text_bert(text, language):
-    language_module = language_module_map[language]
+    language_module = get_language_module(language)
     norm_text = language_module.text_normalize(text)
     phones, tones, word2ph = language_module.g2p(norm_text)
     bert = language_module.get_bert_feature(norm_text, word2ph)