Skip to content

Commit

Permalink
Add stream infer for Bert-VITS2.
Browse files Browse the repository at this point in the history
Fix vits stream infer.
  • Loading branch information
Artrajz committed Nov 7, 2023
1 parent 2f04748 commit 4351b60
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 12 deletions.
45 changes: 41 additions & 4 deletions TTSManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,18 @@ def vits_infer(self, state, encode=True):
audio = np.concatenate(audios, axis=0)
return self.encode(sampling_rate, audio, state["format"]) if encode else audio

def stream_vits_infer(self, task, fname=None):
format = task.get("format", "wav")
voice_obj = self._voice_obj[ModelType.VITS][task.get("id")][1]
task["id"] = self._voice_obj[ModelType.VITS][task.get("id")][0]
sampling_rate = voice_obj.sampling_rate
genertator = voice_obj.get_stream_audio(task, auto_break=True)
# audio = BytesIO()
for chunk in genertator:
encoded_audio = self.encode(sampling_rate, chunk, format)
for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio):
yield encoded_audio_chunk

def stream_vits_infer(self, state, fname=None):
model = self.get_model(ModelType.VITS, state["id"])
state["id"] = self.get_real_id(ModelType.VITS, state["id"])
Expand All @@ -266,11 +278,11 @@ def stream_vits_infer(self, state, fname=None):
if i < sentences_num - 1:
audios.append(brk)

audio = np.concatenate(audios, axis=0)
encoded_audio = self.encode(sampling_rate, audio, state["format"])
audio = np.concatenate(audios, axis=0)
encoded_audio = self.encode(sampling_rate, audio, state["format"])

for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio):
yield encoded_audio_chunk
for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio):
yield encoded_audio_chunk
# if getattr(config, "SAVE_AUDIO", False):
# audio.write(encoded_audio.getvalue())
# if getattr(config, "SAVE_AUDIO", False):
Expand Down Expand Up @@ -357,3 +369,28 @@ def bert_vits2_infer(self, state, encode=True):
audio = np.concatenate(audios)

return self.encode(sampling_rate, audio, state["format"]) if encode else audio

def stream_bert_vits2_infer(self, state, fname=None):
model = self.get_model(ModelType.BERT_VITS2, state["id"])
state["id"] = self.get_real_id(ModelType.BERT_VITS2, state["id"])

# 去除所有多余的空白字符
if state["text"] is not None:
state["text"] = re.sub(r'\s+', ' ', state["text"]).strip()
sampling_rate = model.sampling_rate

sentences_list = split_by_language(state["text"], state["speaker_lang"])

# audios = []

for (text, lang) in sentences_list:
sentences = sentence_split(text, state["max"])
for sentence in sentences:
audio = model.infer(sentence, state["id"], lang, state["sdp_ratio"], state["noise"],
state["noise"], state["length"])
# audios.append(audio)
# audio = np.concatenate(audios, axis=0)
encoded_audio = self.encode(sampling_rate, audio, state["format"])

for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio):
yield encoded_audio_chunk
8 changes: 7 additions & 1 deletion tts_app/templates/pages/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ <h1 class="w-100">
<source src="" type="audio/mp3"/>
Your browser does not support the audio element.
</audio>
<div class="mb-3 form-check">
<input type="checkbox" id="streaming3" onchange="updateLink()">
<label class="form-check-label" data-toggle="tooltip" data-placement="top"
title="按照max分段推理文本,推理好一段即输出,无需等待所有文本都推理完毕">流式响应</label>
</div>
</div>
</div>
</div>
Expand Down Expand Up @@ -392,6 +397,7 @@ <h1 class="w-100">
} else if (model_type == 3) {
var sdp_ratio = document.getElementById("input_sdp_ratio").value;
var url = baseUrl + "/voice/bert-vits2?text=" + text + "&id=" + id;
var streaming = document.getElementById('streaming3');
}
if (format != "") {
url += "&format=" + format;
Expand All @@ -411,7 +417,7 @@ <h1 class="w-100">
if (max != "") {
url += "&max=" + max;
}
if (model_type == 1 && streaming.checked) {
if ((model_type == 1 || model_type == 3) && streaming.checked) {
url += '&streaming=true';
}
if (model_type == 3 && sdp_ratio != "") {
Expand Down
24 changes: 19 additions & 5 deletions tts_app/voice_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def voice_bert_vits2_api():
noisew = float(request_data.get("noisew", current_app.config.get("NOISEW", 0.8)))
sdp_ratio = float(request_data.get("sdp_ratio", current_app.config.get("SDP_RATIO", 0.2)))
max = int(request_data.get("max", current_app.config.get("MAX", 50)))
use_streaming = request_data.get('streaming', False, type=bool)
except Exception as e:
logger.error(f"[{ModelType.BERT_VITS2.value}] {e}")
return make_response("parameter error", 400)
Expand Down Expand Up @@ -418,9 +419,13 @@ def voice_bert_vits2_api():
if current_app.config.get("LANGUAGE_AUTOMATIC_DETECT", []) != []:
speaker_lang = current_app.config.get("LANGUAGE_AUTOMATIC_DETECT")

if use_streaming and format.upper() != "MP3":
format = "mp3"
logger.warning("Streaming response only supports MP3 format.")

fname = f"{str(uuid.uuid1())}.{format}"
file_type = f"audio/{format}"
task = {"text": text,
state = {"text": text,
"id": id,
"format": format,
"length": length,
Expand All @@ -430,11 +435,20 @@ def voice_bert_vits2_api():
"max": max,
"lang": lang,
"speaker_lang": speaker_lang}

if use_streaming:
audio = tts_manager.stream_bert_vits2_infer(state)
response = make_response(audio)
response.headers['Content-Disposition'] = f'attachment; filename={fname}'
response.headers['Content-Type'] = file_type
return response
else:
t1 = time.time()
audio = tts_manager.bert_vits2_infer(state)
t2 = time.time()
logger.info(f"[{ModelType.BERT_VITS2.value}] finish in {(t2 - t1):.2f}s")

t1 = time.time()
audio = tts_manager.bert_vits2_infer(task)
t2 = time.time()
logger.info(f"[{ModelType.BERT_VITS2.value}] finish in {(t2 - t1):.2f}s")


if current_app.config.get("SAVE_AUDIO", False):
logger.debug(f"[{ModelType.BERT_VITS2.value}] {fname}")
Expand Down
4 changes: 2 additions & 2 deletions utils/sentence.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def sentence_split(text: str, max: int) -> list:
if p < len(text):
sentences_list.append(text[p:])

for i in sentences_list:
logging.debug(i)
# for i in sentences_list:
# logging.debug(i)

return sentences_list

Expand Down

0 comments on commit 4351b60

Please sign in to comment.