Skip to content

Commit

Permalink
Update Bert-VITS2 num_tones
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Nov 6, 2023
1 parent 940bfe6 commit ed1897f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 4 deletions.
7 changes: 7 additions & 0 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,38 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
self.bert_model_names = {"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"}
self.ja_bert_dim = 1024
self.ja_extra_str = ""
self.num_tones = num_tones

if self.version in ["1.0", "1.0.0", "1.0.1"]:
self.symbols = symbols_legacy
self.hps_ms.model.n_layers_trans_flow = 3
self.lang = ["zh"]
self.ja_bert_dim = 768
self.num_tones = num_tones_v111

elif self.version in ["1.1.0-transition"]:
self.hps_ms.model.n_layers_trans_flow = 3
self.lang = ["zh", "ja"]
self.bert_model_names["ja"] = "BERT_BASE_JAPANESE_V3"
self.ja_bert_dim = 768
self.ja_extra_str = "_v111"
self.num_tones = num_tones_v111

elif self.version in ["1.1", "1.1.0", "1.1.1"]:
self.hps_ms.model.n_layers_trans_flow = 6
self.lang = ["zh", "ja"]
self.bert_model_names["ja"] = "BERT_BASE_JAPANESE_V3"
self.ja_bert_dim = 768
self.ja_extra_str = "_v111"
self.num_tones = num_tones_v111

elif self.version in ["2.0", "2.0.0"]:
self.hps_ms.model.n_layers_trans_flow = 4
self.bert_model_names = {"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE",
"ja": "DEBERTA_V2_LARGE_JAPANESE",
"en": "DEBERTA_V3_LARGE"}
self.num_tones = num_tones


# self.bert_handler = BertHandler(self.lang)

Expand All @@ -69,6 +75,7 @@ def load_model(self, bert_handler):
n_speakers=self.hps_ms.data.n_speakers,
symbols=self.symbols,
ja_bert_dim=self.ja_bert_dim,
num_tones=self.num_tones,
**self.hps_ms.model).to(self.device)
_ = self.net_g.eval()
bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True, version=self.version)
Expand Down
9 changes: 6 additions & 3 deletions bert_vits2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from bert_vits2.commons import init_weights, get_padding
from bert_vits2.text import num_tones, num_languages
from bert_vits2.text import num_languages


class DurationDiscriminator(nn.Module): # vits2
Expand Down Expand Up @@ -258,7 +258,8 @@ def __init__(self,
p_dropout,
gin_channels=0,
symbols=None,
ja_bert_dim=1024):
ja_bert_dim=1024,
num_tones=None):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
Expand Down Expand Up @@ -601,6 +602,7 @@ def __init__(self,
use_transformer_flow=True,
symbols=None,
ja_bert_dim=1024,
num_tones=None,
**kwargs):

super().__init__()
Expand Down Expand Up @@ -641,7 +643,8 @@ def __init__(self,
p_dropout,
gin_channels=self.enc_gin_channels,
symbols=symbols,
ja_bert_dim=ja_bert_dim
ja_bert_dim=ja_bert_dim,
num_tones=num_tones
)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
Expand Down
9 changes: 9 additions & 0 deletions bert_vits2/text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
from bert_vits2.text.bert_handler import BertHandler


def cleaned_text_to_sequence_v111(cleaned_text, tones, language, _symbol_to_id):
"""version <= 1.1.1"""
phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
tone_start = language_tone_start_map_v111[language]
tones = [i + tone_start for i in tones]
lang_id = language_id_map[language]
lang_ids = [lang_id for i in phones]
return phones, tones, lang_ids

def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
Expand Down
10 changes: 9 additions & 1 deletion bert_vits2/text/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@
"z",
"zy",
]
num_ja_tones = 1
num_ja_tones_v111 = 1
num_ja_tones = 2

# English
en_symbols = [
Expand Down Expand Up @@ -176,12 +177,19 @@
sil_phonemes_ids_legacy = [symbols_legacy.index(i) for i in pu_symbols]

# combine all tones
num_tones_v111 = num_zh_tones + num_ja_tones_v111 + num_en_tones
num_tones = num_zh_tones + num_ja_tones + num_en_tones

# language maps
language_id_map = {"zh": 0, "ja": 1, "en": 2}
num_languages = len(language_id_map.keys())

language_tone_start_map_v111 = {
"zh": 0,
"ja": num_zh_tones,
"en": num_zh_tones + num_ja_tones_v111,
}

language_tone_start_map = {
"zh": 0,
"ja": num_zh_tones,
Expand Down

0 comments on commit ed1897f

Please sign in to comment.