diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py index c0d2341..a48544b 100644 --- a/bert_vits2/bert_vits2.py +++ b/bert_vits2/bert_vits2.py @@ -26,7 +26,10 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): self.symbols = symbols_legacy self.hps_ms.model.n_layers_trans_flow = 3 - if self.version in ["1.1", "1.1.0", "1.1.1"]: + elif self.version in ["1.1.0-transition"]: + self.hps_ms.model.n_layers_trans_flow = 3 + + elif self.version in ["1.1", "1.1.0", "1.1.1"]: self.hps_ms.model.n_layers_trans_flow = 6 self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)} diff --git a/bert_vits2/utils.py b/bert_vits2/utils.py index 53f2332..dcecaee 100644 --- a/bert_vits2/utils.py +++ b/bert_vits2/utils.py @@ -36,14 +36,17 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False new_state_dict[k] = saved_state_dict[k] assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) except: - # For upgrading from the old version + # Handle legacy model versions and provide appropriate warnings if "ja_bert_proj" in k: v = torch.zeros_like(v) if legacy_version is None: logger.error(f"{k} is not in the checkpoint") logger.warning( - f"If you are using an older version of the model, you should add the parameter \"legacy_version\" " - f"to the parameter \"data\" of the model's config.json. For example: \"legacy_version\": \"1.0.1\"") + f"If you're using an older version of the model, consider adding the \"legacy_version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"") + elif "flow.flows.0.enc.attn_layers.3" in k: + logger.error(f"{k} is not in the checkpoint") + logger.warning( + f"If you're using a transitional version, please add the \"legacy_version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.") else: logger.error(f"{k} is not in the checkpoint") diff --git a/utils/lang_dict.py b/utils/lang_dict.py index a3c74e4..2c72caf 100644 --- a/utils/lang_dict.py +++ b/utils/lang_dict.py @@ -24,5 +24,6 @@ "bert_vits2_v1.0.1": ["zh"], "bert_vits2_v1.1": ["zh", "ja"], "bert_vits2_v1.1.0": ["zh", "ja"], + "bert_vits2_v1.1.0-transition": ["zh", "ja"], "bert_vits2_v1.1.1": ["zh", "ja"], }