-
Notifications
You must be signed in to change notification settings - Fork 8
/
api.py
50 lines (47 loc) · 2.36 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pypinyin import lazy_pinyin, Style
import torch
import json

# Checkpoint paths keyed by logical model file name.
# NOTE(review): absolute path from the original author's machine — adjust
# before running elsewhere.
MODELS = {
    'vqvae.pth': '/home/hyc/detail_tts/logs/2024-08-18-11-29-08/model-480.pt',
}
# Use the first GPU when available; otherwise fall back to CPU instead of
# crashing on the first `.to(device)`.
# NOTE(review): CPU inference assumes load_model can map the checkpoint to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
from bpe_tokenizers.voice_tokenizer import VoiceBpeTokenizer
import torch.nn.functional as F

# Reference audio file; it is loaded later in this script and its mel
# spectrogram is passed to the model's infer().
cond_audio = '1.wav'

# Input sentence to synthesize. The commented-out lines are alternative test
# sentences (including tongue twisters); swap one in to try it.
text = "大家好,今天来点大家想看的东西"
# text = "霞浦县衙城镇乌旗瓦窑村水位猛涨"
# text = '高德官方网站,拥有全面,精准的地点信息,公交驾车路线规划,特色语音导航,商家团购,优惠信息'
# text = '四是四,十是十,十四是十四,四十是四十'
# text = '八百标兵奔北坡,炮兵并排北边跑。炮兵怕把标兵碰,标兵怕碰炮兵炮'
# text = '黑化肥发灰,灰化肥发黑,黑化肥挥发会发灰,灰化肥挥发会发黑。'
# text = '先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也'

# Convert the Chinese characters to space-separated pinyin with a trailing tone
# digit (Style.TONE3); the neutral tone is written as "5".
text = ' '.join(lazy_pinyin(text, style=Style.TONE3, neutral_tone_with_five=True))
# Surround the pinyin string with a leading and a trailing space.
text = f' {text} '

tokenizer = VoiceBpeTokenizer('bpe_tokenizers/zh_tokenizer.json')
# Build the int32 token tensor directly on the target device with a leading
# batch dimension of 1 (assumes encode() returns a flat list of token ids).
text_tokens = torch.tensor(tokenizer.encode(text), dtype=torch.int32, device=device).unsqueeze(0)
# Append one zero token at the end of the sequence. This may not be necessary.
# F.pad keeps the tensor on its current device, so no extra .to(device) is needed.
text_tokens = F.pad(text_tokens, (0, 1))
print(text)
print(text_tokens)
from prepare.load_infer import load_model
import torchaudio
from vqvae.utils.data_utils import spectrogram_torch,HParams,mel_spectrogram_torch

# Single source of truth for the VQ-VAE config: it is used both to build the
# model and to parameterize the mel spectrogram of the reference audio.
VQVAE_CONFIG = 'vqvae/configs/config_24k.json'

# Read the hyper-parameters with a context manager so the file handle is closed.
with open(VQVAE_CONFIG, encoding='utf-8') as config_file:
    hps = HParams(**json.load(config_file))
# Sample rate the mel front-end is configured for; the reference audio is
# resampled to it. (24 kHz is expected for config_24k.json — confirm.)
sampling_rate = hps.data.sampling_rate

vqvae = load_model('vqvae', MODELS['vqvae.pth'], VQVAE_CONFIG, device)

# torchaudio.load returns (waveform[channels, frames], sample_rate).
audio, sr = torchaudio.load(cond_audio)
# Keep only the first channel of multi-channel audio, preserving a (1, frames) shape.
if audio.shape[0] > 1:
    audio = audio[:1]
audio = torchaudio.transforms.Resample(sr, sampling_rate)(audio)

# Mel spectrogram of the reference audio, computed on CPU then moved to device.
spec = mel_spectrogram_torch(audio, hps.data.filter_length,
                             hps.data.n_mel_channels,
                             sampling_rate,
                             hps.data.hop_length,
                             hps.data.win_length,
                             hps.data.mel_fmin,
                             hps.data.mel_fmax).to(device)

# Per-item lengths for the batch of one. Both length tensors live on the same
# device as the data they describe; the original left text_lengths on CPU,
# which breaks models that build masks from it next to CUDA tensors.
spec_lengths = torch.LongTensor([spec.shape[-1]]).to(device)
text_lengths = torch.LongTensor([text_tokens.shape[-1]]).to(device)

with torch.no_grad():
    wav = vqvae.infer(text_tokens, text_lengths, spec, spec_lengths)
# Drop the leading batch dimension and write the result.
# NOTE(review): assumes the model outputs audio at the config's sample rate.
torchaudio.save('gen.wav', wav.squeeze(0).cpu(), sampling_rate)