diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py
index cbcca34..cae1900 100644
--- a/bert_vits2/bert_vits2.py
+++ b/bert_vits2/bert_vits2.py
@@ -6,6 +6,7 @@
 from bert_vits2.models import SynthesizerTrn
 from bert_vits2.text import *
 from bert_vits2.text.cleaner import clean_text
+from bert_vits2.utils import process_legacy_versions
 from utils import classify_language, get_hparams_from_file, lang_dict
 from utils.sentence import sentence_split_and_markup, cut
 
@@ -16,9 +17,18 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
         self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
         self.speakers = [item[0] for item in
                          sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
+        self.symbols = symbols
+
+        # Compatible with legacy versions
+        self.version = process_legacy_versions(self.hps_ms)
+
+        if self.version in ["1.0", "1.0.1"]:
+            self.symbols = symbols_legacy
+            self.hps_ms.model.n_layers_trans_flow = 3
+
+        if self.version in ["1.1"]:
+            self.hps_ms.model.n_layers_trans_flow = 6
-        self.legacy = kwargs.get('legacy', getattr(self.hps_ms.data, 'legacy', None))
-        self.symbols = symbols_legacy if self.legacy else symbols
 
         self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
 
         self.net_g = SynthesizerTrn(
@@ -33,7 +43,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
         self.load_model(model)
 
     def load_model(self, model):
-        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, legacy=self.legacy)
+        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, legacy_version=self.version)
 
     def get_speakers(self):
         return self.speakers
diff --git a/bert_vits2/models.py b/bert_vits2/models.py
index f793a84..72050c2 100644
--- a/bert_vits2/models.py
+++ b/bert_vits2/models.py
@@ -26,9 +26,11 @@ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_cha
 
         self.drop = nn.Dropout(p_dropout)
         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        # self.norm_1 = modules.LayerNorm(filter_channels)
-        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        # self.norm_2 = modules.LayerNorm(filter_channels)
+        self.norm_1 = modules.LayerNorm(filter_channels)
+        self.conv_2 = nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_2 = modules.LayerNorm(filter_channels)
         self.dur_proj = nn.Conv1d(1, filter_channels, 1)
 
         self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
@@ -36,8 +38,8 @@ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_cha
         self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
 
-        # if gin_channels != 0:
-        #     self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, in_channels, 1)
 
         self.output_layer = nn.Sequential(
             nn.Linear(filter_channels, 1),
@@ -48,13 +50,13 @@ def forward_probability(self, x, x_mask, dur, g=None):
         dur = self.dur_proj(dur)
         x = torch.cat([x, dur], dim=1)
         x = self.pre_out_conv_1(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.pre_out_norm_1(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.drop(x)
         x = self.pre_out_conv_2(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.pre_out_norm_2(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.drop(x)
         x = x * x_mask
         x = x.transpose(1, 2)
         output_prob = self.output_layer(x)
@@ -62,17 +64,17 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
         x = torch.detach(x)
-        # if g is not None:
-        #     g = torch.detach(g)
-        #     x = x + self.cond(g)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond(g)
         x = self.conv_1(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.norm_1(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
         x = self.conv_2(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.norm_2(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
 
         output_probs = []
         for dur in [dur_r, dur_hat]:
@@ -590,7 +592,7 @@ def __init__(self,
                  gin_channels=256,
                  use_sdp=True,
                  n_flow_layer=4,
-                 n_layers_trans_flow=3,
+                 n_layers_trans_flow=6,
                  flow_share_parameter=False,
                  use_transformer_flow=True,
                  **kwargs):
diff --git a/bert_vits2/utils.py b/bert_vits2/utils.py
index a286ae0..53f2332 100644
--- a/bert_vits2/utils.py
+++ b/bert_vits2/utils.py
@@ -9,7 +9,7 @@
 logger = logging
 
 
-def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, legacy=False):
+def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, legacy_version=None):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
     iteration = checkpoint_dict['iteration']
@@ -39,11 +39,11 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
             # For upgrading from the old version
             if "ja_bert_proj" in k:
                 v = torch.zeros_like(v)
-                if not legacy:
+                if legacy_version is None:
                     logger.error(f"{k} is not in the checkpoint")
                     logger.warning(
-                        f"If you are using an older version of the model, you should add the parameter \"legacy\" "
-                        f"to the parameter \"data\" of the model's config.json. For example: \"legacy\": \"1.0.1\"")
+                        f"If you are using an older version of the model, you should add the parameter \"legacy_version\" "
+                        f"to the parameter \"data\" of the model's config.json. For example: \"legacy_version\": \"1.0.1\"")
             else:
                 logger.error(f"{k} is not in the checkpoint")
@@ -56,3 +56,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
     logger.info("Loaded checkpoint '{}' (iteration {})".format(
         checkpoint_path, iteration))
     return model, optimizer, learning_rate, iteration
+
+
+def process_legacy_versions(hps):
+    legacy_version = getattr(hps.data, "legacy", getattr(hps.data, "legacy_version", None))
+    if legacy_version:
+        prefix = legacy_version[0].lower()
+        if prefix == "v":
+            legacy_version = legacy_version[1:]
+    return legacy_version
diff --git a/utils/lang_dict.py b/utils/lang_dict.py
index 478ad11..e556267 100644
--- a/utils/lang_dict.py
+++ b/utils/lang_dict.py
@@ -19,5 +19,6 @@
                                  "YB"],
     "bert_chinese_cleaners": ["zh"],
     "bert_vits2": ["zh", "ja"],
+    "bert_vits2_v1.0": ["zh"],
     "bert_vits2_v1.0.1": ["zh"]
 }
diff --git a/utils/load_model.py b/utils/load_model.py
index 371287c..c28f6f6 100644
--- a/utils/load_model.py
+++ b/utils/load_model.py
@@ -5,6 +5,7 @@
 import numpy as np
 
 import utils
+from bert_vits2.utils import process_legacy_versions
 from utils.data_utils import check_is_none, HParams
 from vits import VITS
 from voice import TTS
@@ -93,15 +94,6 @@ def parse_models(model_list):
     return categorized_models
 
 
-def process_legacy_versions(hps):
-    legacy_versions = getattr(hps.data, "legacy", None)
-    if legacy_versions:
-        prefix = legacy_versions[0].lower()
-        if prefix == "v":
-            legacy_versions = legacy_versions[1:]
-    return legacy_versions
-
-
 def merge_models(model_list, model_class, model_type, additional_arg=None):
     id_mapping_objs = []
     speakers = []
@@ -117,7 +109,6 @@ def merge_models(model_list, model_class, model_type, additional_arg=None):
 
         if model_type == "bert_vits2":
             legacy_versions = process_legacy_versions(hps)
-            obj_args.update({"legacy": legacy_versions})
             key = f"{model_type}_v{legacy_versions}" if legacy_versions else model_type
         else:
             key = getattr(hps.data, "text_cleaners", ["none"])[0]