update: Bert_VITS2 v1.1.1

Artrajz committed Oct 5, 2023
1 parent 014f71c commit f71f438

Showing 5 changed files with 51 additions and 38 deletions.
16 changes: 13 additions & 3 deletions bert_vits2/bert_vits2.py
@@ -6,6 +6,7 @@
 from bert_vits2.models import SynthesizerTrn
 from bert_vits2.text import *
 from bert_vits2.text.cleaner import clean_text
+from bert_vits2.utils import process_legacy_versions
 from utils import classify_language, get_hparams_from_file, lang_dict
 from utils.sentence import sentence_split_and_markup, cut

@@ -16,9 +17,18 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
         self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
         self.speakers = [item[0] for item in
                          sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
+        self.symbols = symbols
+
+        # Compatible with legacy versions
+        self.version = process_legacy_versions(self.hps_ms)
+
+        if self.version in ["1.0", "1.0.1"]:
+            self.symbols = symbols_legacy
+            self.hps_ms.model.n_layers_trans_flow = 3
+
+        if self.version in ["1.1"]:
+            self.hps_ms.model.n_layers_trans_flow = 6
 
-        self.legacy = kwargs.get('legacy', getattr(self.hps_ms.data, 'legacy', None))
-        self.symbols = symbols_legacy if self.legacy else symbols
         self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
 
         self.net_g = SynthesizerTrn(
@@ -33,7 +43,7 @@ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
         self.load_model(model)
 
     def load_model(self, model):
-        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, legacy=self.legacy)
+        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, legacy_version=self.version)
 
     def get_speakers(self):
         return self.speakers
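For orientation: the constructor now derives a version string from the model's hyperparameters instead of the old boolean legacy kwarg, and gates both the symbol table and the transformer-flow depth on it. A minimal sketch of that gate (illustrative only, not from the repo; the string names stand in for the real symbols/symbols_legacy tables in bert_vits2.text):

def resolve_compat(version):
    """Mirror the __init__ gate: returns (symbol table, n_layers_trans_flow)."""
    if version in ("1.0", "1.0.1"):
        return "symbols_legacy", 3  # v1.0.x checkpoints: old symbol set, 3 flow layers
    if version == "1.1":
        return "symbols", 6
    return "symbols", 6             # None -> current model (v1.1.1), new default

assert resolve_compat("1.0.1") == ("symbols_legacy", 3)
assert resolve_compat(None) == ("symbols", 6)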
44 changes: 23 additions & 21 deletions bert_vits2/models.py
@@ -26,18 +26,20 @@ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels
 
         self.drop = nn.Dropout(p_dropout)
         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        # self.norm_1 = modules.LayerNorm(filter_channels)
-        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        # self.norm_2 = modules.LayerNorm(filter_channels)
+        self.norm_1 = modules.LayerNorm(filter_channels)
+        self.conv_2 = nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_2 = modules.LayerNorm(filter_channels)
         self.dur_proj = nn.Conv1d(1, filter_channels, 1)
 
         self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
         self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
 
-        # if gin_channels != 0:
-        #     self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, in_channels, 1)
 
         self.output_layer = nn.Sequential(
             nn.Linear(filter_channels, 1),
@@ -48,31 +50,31 @@ def forward_probability(self, x, x_mask, dur, g=None):
         dur = self.dur_proj(dur)
         x = torch.cat([x, dur], dim=1)
         x = self.pre_out_conv_1(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.pre_out_norm_1(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.drop(x)
         x = self.pre_out_conv_2(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.pre_out_norm_2(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.drop(x)
         x = x * x_mask
         x = x.transpose(1, 2)
         output_prob = self.output_layer(x)
         return output_prob
 
     def forward(self, x, x_mask, dur_r, dur_hat, g=None):
         x = torch.detach(x)
-        # if g is not None:
-        #     g = torch.detach(g)
-        #     x = x + self.cond(g)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond(g)
         x = self.conv_1(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.norm_1(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
         x = self.conv_2(x * x_mask)
-        # x = torch.relu(x)
-        # x = self.norm_2(x)
-        # x = self.drop(x)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
 
         output_probs = []
         for dur in [dur_r, dur_hat]:
@@ -590,7 +592,7 @@ def __init__(self,
                  gin_channels=256,
                  use_sdp=True,
                  n_flow_layer=4,
-                 n_layers_trans_flow=3,
+                 n_layers_trans_flow=6,
                  flow_share_parameter=False,
                  use_transformer_flow=True,
                  **kwargs):
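The DurationDiscriminator edits above mostly activate code that was previously commented out: each conv is now followed by ReLU, LayerNorm, and dropout, and the speaker embedding is mixed in via self.cond when gin_channels != 0. A hedged, standalone sketch of that conv -> relu -> norm -> dropout pattern (torch's nn.LayerNorm over the channel axis stands in for modules.LayerNorm):

import torch
import torch.nn as nn

B, C, T = 2, 8, 16            # batch, channels, frames
x = torch.randn(B, C, T)
x_mask = torch.ones(B, 1, T)  # 1 = valid frame

conv = nn.Conv1d(C, C, kernel_size=3, padding=1)
norm = nn.LayerNorm(C)        # stand-in for modules.LayerNorm
drop = nn.Dropout(0.5)

h = conv(x * x_mask)          # mask padding, then convolve: (B, C, T)
h = torch.relu(h)
h = norm(h.transpose(1, 2)).transpose(1, 2)  # normalize over channels
h = drop(h)
print(h.shape)                # torch.Size([2, 8, 16])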
17 changes: 13 additions & 4 deletions bert_vits2/utils.py
@@ -9,7 +9,7 @@
 logger = logging
 
 
-def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, legacy=False):
+def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, legacy_version=None):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
     iteration = checkpoint_dict['iteration']
@@ -39,11 +39,11 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
             # For upgrading from the old version
             if "ja_bert_proj" in k:
                 v = torch.zeros_like(v)
-                if not legacy:
+                if legacy_version is None:
                     logger.error(f"{k} is not in the checkpoint")
                     logger.warning(
-                        f"If you are using an older version of the model, you should add the parameter \"legacy\" "
-                        f"to the parameter \"data\" of the model's config.json. For example: \"legacy\": \"1.0.1\"")
+                        f"If you are using an older version of the model, you should add the parameter \"legacy_version\" "
+                        f"to the parameter \"data\" of the model's config.json. For example: \"legacy_version\": \"1.0.1\"")
                 else:
                     logger.error(f"{k} is not in the checkpoint")

@@ -56,3 +56,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
     logger.info("Loaded checkpoint '{}' (iteration {})".format(
         checkpoint_path, iteration))
     return model, optimizer, learning_rate, iteration
+
+
+def process_legacy_versions(hps):
+    legacy_version = getattr(hps.data, "legacy", getattr(hps.data, "legacy_version", None))
+    if legacy_version:
+        prefix = legacy_version[0].lower()
+        if prefix == "v":
+            legacy_version = legacy_version[1:]
+    return legacy_version
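process_legacy_versions has moved here from utils/load_model.py and is slightly generalized: it also reads a "legacy_version" key and strips a leading "v"/"V", matching the config.json hint in the warning above (add "legacy_version": "1.0.1" under "data" for old models). A small usage sketch, with SimpleNamespace standing in for the parsed hparams:

from types import SimpleNamespace
from bert_vits2.utils import process_legacy_versions  # assumes the repo root is on sys.path

hps_old = SimpleNamespace(data=SimpleNamespace(legacy="v1.0.1"))
hps_new = SimpleNamespace(data=SimpleNamespace(legacy_version="1.1"))
hps_cur = SimpleNamespace(data=SimpleNamespace())

print(process_legacy_versions(hps_old))  # "1.0.1" -- leading "v" stripped
print(process_legacy_versions(hps_new))  # "1.1"   -- falls back to "legacy_version"
print(process_legacy_versions(hps_cur))  # None    -- treated as a current model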
1 change: 1 addition & 0 deletions utils/lang_dict.py
@@ -19,5 +19,6 @@
            "YB"],
     "bert_chinese_cleaners": ["zh"],
     "bert_vits2": ["zh", "ja"],
+    "bert_vits2_v1.0": ["zh"],
     "bert_vits2_v1.0.1": ["zh"]
 }
11 changes: 1 addition & 10 deletions utils/load_model.py
@@ -5,6 +5,7 @@
 import numpy as np
 
 import utils
+from bert_vits2.utils import process_legacy_versions
 from utils.data_utils import check_is_none, HParams
 from vits import VITS
 from voice import TTS
@@ -93,15 +94,6 @@ def parse_models(model_list):
     return categorized_models
 
 
-def process_legacy_versions(hps):
-    legacy_versions = getattr(hps.data, "legacy", None)
-    if legacy_versions:
-        prefix = legacy_versions[0].lower()
-        if prefix == "v":
-            legacy_versions = legacy_versions[1:]
-    return legacy_versions
-
-
 def merge_models(model_list, model_class, model_type, additional_arg=None):
     id_mapping_objs = []
     speakers = []
@@ -117,7 +109,6 @@ def merge_models(model_list, model_class, model_type, additional_arg=None):
 
     if model_type == "bert_vits2":
         legacy_versions = process_legacy_versions(hps)
-        obj_args.update({"legacy": legacy_versions})
         key = f"{model_type}_v{legacy_versions}" if legacy_versions else model_type
     else:
         key = getattr(hps.data, "text_cleaners", ["none"])[0]
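The key computed above ("bert_vits2_v{version}" for legacy models, plain "bert_vits2" otherwise) presumably indexes utils/lang_dict.py to select the supported languages, which is why this commit adds the "bert_vits2_v1.0" entry. A sketch of that lookup (supported_langs is a hypothetical helper; lang_dict is abbreviated to the entries shown in this diff):

lang_dict = {
    "bert_vits2": ["zh", "ja"],
    "bert_vits2_v1.0": ["zh"],
    "bert_vits2_v1.0.1": ["zh"],
}

def supported_langs(model_type, legacy_version):
    key = f"{model_type}_v{legacy_version}" if legacy_version else model_type
    return lang_dict[key]

print(supported_langs("bert_vits2", "1.0"))  # ['zh'] -- legacy models are Chinese-only
print(supported_langs("bert_vits2", None))   # ['zh', 'ja']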