Skip to content

Commit

Permalink
update: Compatible with the transitional version of Bert-VITS2 v1.1.
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Oct 6, 2023
1 parent d91729f commit bf375ca
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
5 changes: 4 additions & 1 deletion bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
self.symbols = symbols_legacy
self.hps_ms.model.n_layers_trans_flow = 3

if self.version in ["1.1", "1.1.0", "1.1.1"]:
elif self.version in ["1.1.0-transition"]:
self.hps_ms.model.n_layers_trans_flow = 3

elif self.version in ["1.1", "1.1.0", "1.1.1"]:
self.hps_ms.model.n_layers_trans_flow = 6

self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
Expand Down
9 changes: 6 additions & 3 deletions bert_vits2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,17 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
new_state_dict[k] = saved_state_dict[k]
assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
except:
# For upgrading from the old version
# Handle legacy model versions and provide appropriate warnings
if "ja_bert_proj" in k:
v = torch.zeros_like(v)
if legacy_version is None:
logger.error(f"{k} is not in the checkpoint")
logger.warning(
f"If you are using an older version of the model, you should add the parameter \"legacy_version\" "
f"to the parameter \"data\" of the model's config.json. For example: \"legacy_version\": \"1.0.1\"")
f"If you're using an older version of the model, consider adding the \"legacy_version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"")
elif "flow.flows.0.enc.attn_layers.3" in k:
logger.error(f"{k} is not in the checkpoint")
logger.warning(
f"If you're using a transitional version, please add the \"legacy_version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
else:
logger.error(f"{k} is not in the checkpoint")

Expand Down
1 change: 1 addition & 0 deletions utils/lang_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@
"bert_vits2_v1.0.1": ["zh"],
"bert_vits2_v1.1": ["zh", "ja"],
"bert_vits2_v1.1.0": ["zh", "ja"],
"bert_vits2_v1.1.0-transition": ["zh", "ja"],
"bert_vits2_v1.1.1": ["zh", "ja"],
}

0 comments on commit bf375ca

Please sign in to comment.