From f94d9fd92afb64b10f03e7df5b0d3800057525e6 Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Tue, 10 Oct 2023 10:36:37 +0800 Subject: [PATCH] fix: classify language add self.lang --- bert_vits2/bert_vits2.py | 9 +++++++-- vits/vits.py | 5 ++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py index b54eb0a..d1d58b3 100644 --- a/bert_vits2/bert_vits2.py +++ b/bert_vits2/bert_vits2.py @@ -7,6 +7,7 @@ from bert_vits2.text import * from bert_vits2.text.cleaner import clean_text from bert_vits2.utils import process_legacy_versions +from contants import ModelType from utils import classify_language, get_hparams_from_file, lang_dict from utils.sentence import sentence_split_and_markup, cut @@ -18,7 +19,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): self.speakers = [item[0] for item in sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])] self.symbols = symbols - + # Compatible with legacy versions self.version = process_legacy_versions(self.hps_ms) @@ -32,6 +33,9 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): elif self.version in ["1.1", "1.1.0", "1.1.1"]: self.hps_ms.model.n_layers_trans_flow = 6 + key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value + self.lang = lang_dict.get(key, ["unknown"]) + self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)} self.net_g = SynthesizerTrn( @@ -115,7 +119,8 @@ def get_audio(self, voice, auto_break=False): max = voice.get("max", 50) # sentence_list = sentence_split_and_markup(text, max, "ZH", ["zh"]) if lang == "auto": - lang = classify_language(text, target_languages=lang_dict["bert_vits2"]) + lang = classify_language(text, target_languages=self.lang) + sentence_list = cut(text, max) audios = [] for sentence in sentence_list: diff --git a/vits/vits.py b/vits/vits.py index 9eb6046..ddc5468 100644 --- a/vits/vits.py +++ b/vits/vits.py @@ -5,7 +5,7 @@ from torch import no_grad, LongTensor, inference_mode, FloatTensor import utils from contants import ModelType -from utils import get_hparams_from_file +from utils import get_hparams_from_file, lang_dict from utils.sentence import sentence_split_and_markup from vits import commons from vits.mel_processing import spectrogram_torch @@ -39,6 +39,9 @@ def __init__(self, model, config, additional_model=None, model_type=None, device _ = self.net_g_ms.eval() self.device = device + key = getattr(self.hps_ms.data, "text_cleaners", ["none"])[0] + self.lang = lang_dict.get(key, ["unknown"]) + # load model self.load_model(model, additional_model)