Skip to content

Commit

Permalink
update: add Japanese (JP) language support to bert_vits2
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Sep 21, 2023
1 parent 893ee70 commit 0d9ebf6
Show file tree
Hide file tree
Showing 13 changed files with 1,330 additions and 261 deletions.
2 changes: 1 addition & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flask_apscheduler import APScheduler
from functools import wraps
from utils.utils import clean_folder, check_is_none
from utils.merge import merge_model
from utils.load_model import merge_model
from io import BytesIO

app = Flask(__name__)
Expand Down
15 changes: 13 additions & 2 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from bert_vits2.models import SynthesizerTrn
from bert_vits2.text import symbols, cleaned_text_to_sequence, get_bert
from bert_vits2.text.cleaner import clean_text
from utils.nlp import sentence_split, cut
from bert_vits2.text.symbols import get_symbols
from utils.sentence import sentence_split, cut


class Bert_VITS2:
Expand All @@ -16,11 +17,20 @@ def __init__(self, model, config, device=torch.device("cpu")):
self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
self.speakers = [item[0] for item in
sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]

self.legacy = getattr(self.hps_ms.data, 'legacy', False)
symbols, num_tones, self.language_id_map, num_languages, self.language_tone_start_map = get_symbols(
legacy=self.legacy)
self._symbol_to_id = {s: i for i, s in enumerate(symbols)}

self.net_g = SynthesizerTrn(
len(symbols),
self.hps_ms.data.filter_length // 2 + 1,
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.hps_ms.data.n_speakers,
symbols=symbols,
num_tones=num_tones,
num_languages=num_languages,
**self.hps_ms.model).to(device)
_ = self.net_g.eval()
self.device = device
Expand All @@ -35,7 +45,8 @@ def get_speakers(self):
def get_text(self, text, language_str, hps):
norm_text, phone, tone, word2ph = clean_text(text, language_str)
# print([f"{p}{t}" for p, t in zip(phone, tone)])
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id,
self.language_tone_start_map, self.language_id_map)

if hps.data.add_blank:
phone = commons.intersperse(phone, 0)
Expand Down
15 changes: 12 additions & 3 deletions bert_vits2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from bert_vits2.commons import init_weights, get_padding
from bert_vits2.text import symbols, num_tones, num_languages


class DurationDiscriminator(nn.Module): # vits2
Expand Down Expand Up @@ -254,7 +253,10 @@ def __init__(self,
n_layers,
kernel_size,
p_dropout,
gin_channels=0):
gin_channels=0,
symbols=None,
num_tones=None,
num_languages=None):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
Expand Down Expand Up @@ -620,6 +622,9 @@ def __init__(self,
self.current_mas_noise_scale = self.mas_noise_scale_initial
if self.use_spk_conditioned_encoder and gin_channels > 0:
self.enc_gin_channels = gin_channels
symbols = kwargs.get("symbols")
num_tones = kwargs.get("num_tones")
num_languages = kwargs.get("num_languages")
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
Expand All @@ -628,7 +633,11 @@ def __init__(self,
n_layers,
kernel_size,
p_dropout,
gin_channels=self.enc_gin_channels)
gin_channels=self.enc_gin_channels,
symbols=symbols,
num_tones=num_tones,
num_languages=num_languages
)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
Expand Down
37 changes: 19 additions & 18 deletions bert_vits2/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from bert_vits2.text.symbols import *
from .chinese_bert import get_bert_feature as zh_bert
from .english_bert_mock import get_bert_feature as en_bert
from bert_vits2.text.symbols import punctuation

_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def cleaned_text_to_sequence(cleaned_text, tones, language):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
'''
def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id, language_tone_start_map, language_id_map):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
"""
phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
tone_start = language_tone_start_map[language]
tones = [i + tone_start for i in tones]
Expand All @@ -21,9 +16,15 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):


def get_bert(norm_text, word2ph, language):
    """Return BERT features for *norm_text* using the model for *language*.

    Args:
        norm_text: normalized input text.
        word2ph: word-to-phoneme alignment info, forwarded to the backend.
        language: language code — one of "ZH", "EN", "JP".

    Returns:
        The feature tensor produced by the language-specific
        ``get_bert_feature`` implementation.

    Raises:
        ValueError: if *language* is not a supported language code.
    """
    # Import lazily so only the requested language's BERT model is loaded
    # (each backend module loads its model at import time).
    if language == "ZH":
        from .chinese_bert import get_bert_feature as lang_bert_func
    elif language == "EN":
        from .english_bert_mock import get_bert_feature as lang_bert_func
    elif language == "JP":
        from .japanese_bert import get_bert_feature as lang_bert_func
    else:
        # Previously an unknown language fell through and the call below
        # raised an opaque UnboundLocalError; fail with a clear message.
        raise ValueError(f"Unsupported language for get_bert: {language!r}")

    bert = lang_bert_func(norm_text, word2ph)
    return bert
9 changes: 3 additions & 6 deletions bert_vits2/text/chinese_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
from transformers import AutoTokenizer, AutoModelForMaskedLM
from logger import logger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
logger.info("Loading chinese-roberta-wwm-ext-large...")
tokenizer = AutoTokenizer.from_pretrained(config.ABS_PATH + "/bert_vits2/bert/chinese-roberta-wwm-ext-large")
model = AutoModelForMaskedLM.from_pretrained(config.ABS_PATH + "/bert_vits2/bert/chinese-roberta-wwm-ext-large").to(
device)
config.DEVICE)
logger.info("Loading finished.")
except Exception as e:
logger.error(e)
logger.error(f"Please download model from hfl/chinese-roberta-wwm-ext-large.")
logger.error(f"Please download pytorch_model.bin from hfl/chinese-roberta-wwm-ext-large.")


def get_bert_feature(text, word2ph):
def get_bert_feature(text, word2ph, device=config.DEVICE):
with torch.no_grad():
inputs = tokenizer(text, return_tensors='pt')
for i in inputs:
Expand All @@ -37,7 +35,6 @@ def get_bert_feature(text, word2ph):


if __name__ == '__main__':
# feature = get_bert_feature('你好,我是说的道理。')
import torch

word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
Expand Down
Loading

0 comments on commit 0d9ebf6

Please sign in to comment.