Skip to content

Commit

Permalink
update: version info
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Oct 23, 2023
1 parent 45e6505 commit bc391ae
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
self.load_model(model)

def load_model(self, model):
bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, legacy_version=self.version)
bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, version=self.version)

def get_speakers(self):
return self.speakers
Expand Down
8 changes: 4 additions & 4 deletions bert_vits2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
logger = logging


def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, legacy_version=None):
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
Expand Down Expand Up @@ -39,14 +39,14 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
# Handle legacy model versions and provide appropriate warnings
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
if legacy_version is None:
if version is None:
logger.error(f"{k} is not in the checkpoint")
logger.warning(
f"If you're using an older version of the model, consider adding the \"legacy_version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"")
f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"")
elif "flow.flows.0.enc.attn_layers.3" in k:
logger.error(f"{k} is not in the checkpoint")
logger.warning(
f"If you're using a transitional version, please add the \"legacy_version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
else:
logger.error(f"{k} is not in the checkpoint")

Expand Down

0 comments on commit bc391ae

Please sign in to comment.