From bc391ae930399af33d7f8f0ab69eaa034ae0ecc5 Mon Sep 17 00:00:00 2001 From: Artrajz <969242373@qq.com> Date: Mon, 23 Oct 2023 08:55:52 +0800 Subject: [PATCH] update: version info --- bert_vits2/bert_vits2.py | 2 +- bert_vits2/utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bert_vits2/bert_vits2.py b/bert_vits2/bert_vits2.py index 88c7ee9..6a3d578 100644 --- a/bert_vits2/bert_vits2.py +++ b/bert_vits2/bert_vits2.py @@ -53,7 +53,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs): self.load_model(model) def load_model(self, model): - bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, legacy_version=self.version) + bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, version=self.version) def get_speakers(self): return self.speakers diff --git a/bert_vits2/utils.py b/bert_vits2/utils.py index b195423..5ac6252 100644 --- a/bert_vits2/utils.py +++ b/bert_vits2/utils.py @@ -9,7 +9,7 @@ logger = logging -def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, legacy_version=None): +def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None): assert os.path.isfile(checkpoint_path) checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') iteration = checkpoint_dict['iteration'] @@ -39,14 +39,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False # Handle legacy model versions and provide appropriate warnings if "ja_bert_proj" in k: v = torch.zeros_like(v) - if legacy_version is None: + if version is None: logger.error(f"{k} is not in the checkpoint") logger.warning( - f"If you're using an older version of the model, consider adding the \"legacy_version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"") + f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"") elif "flow.flows.0.enc.attn_layers.3" in k: logger.error(f"{k} is not in the checkpoint") logger.warning( - f"If you're using a transitional version, please add the \"legacy_version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.") + f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.") else: logger.error(f"{k} is not in the checkpoint")