Skip to content

Commit

Permalink
fix: classify language
Browse files Browse the repository at this point in the history
add self.lang
  • Loading branch information
Artrajz committed Oct 10, 2023
1 parent b5d3356 commit f94d9fd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
9 changes: 7 additions & 2 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bert_vits2.text import *
from bert_vits2.text.cleaner import clean_text
from bert_vits2.utils import process_legacy_versions
from contants import ModelType
from utils import classify_language, get_hparams_from_file, lang_dict
from utils.sentence import sentence_split_and_markup, cut

Expand All @@ -18,7 +19,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
self.speakers = [item[0] for item in
sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
self.symbols = symbols

# Compatible with legacy versions
self.version = process_legacy_versions(self.hps_ms)

Expand All @@ -32,6 +33,9 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
elif self.version in ["1.1", "1.1.0", "1.1.1"]:
self.hps_ms.model.n_layers_trans_flow = 6

key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value
self.lang = lang_dict.get(key, ["unknown"])

self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}

self.net_g = SynthesizerTrn(
Expand Down Expand Up @@ -115,7 +119,8 @@ def get_audio(self, voice, auto_break=False):
max = voice.get("max", 50)
# sentence_list = sentence_split_and_markup(text, max, "ZH", ["zh"])
if lang == "auto":
lang = classify_language(text, target_languages=lang_dict["bert_vits2"])
lang = classify_language(text, target_languages=self.lang)

sentence_list = cut(text, max)
audios = []
for sentence in sentence_list:
Expand Down
5 changes: 4 additions & 1 deletion vits/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import no_grad, LongTensor, inference_mode, FloatTensor
import utils
from contants import ModelType
from utils import get_hparams_from_file
from utils import get_hparams_from_file, lang_dict
from utils.sentence import sentence_split_and_markup
from vits import commons
from vits.mel_processing import spectrogram_torch
Expand Down Expand Up @@ -39,6 +39,9 @@ def __init__(self, model, config, additional_model=None, model_type=None, device
_ = self.net_g_ms.eval()
self.device = device

key = getattr(self.hps_ms.data, "text_cleaners", ["none"])[0]
self.lang = lang_dict.get(key, ["unknown"])

# load model
self.load_model(model, additional_model)

Expand Down

0 comments on commit f94d9fd

Please sign in to comment.