diff --git a/TTSManager.py b/TTSManager.py index d482646..658a60e 100644 --- a/TTSManager.py +++ b/TTSManager.py @@ -244,6 +244,18 @@ def vits_infer(self, state, encode=True): audio = np.concatenate(audios, axis=0) return self.encode(sampling_rate, audio, state["format"]) if encode else audio + def stream_vits_infer(self, task, fname=None): + format = task.get("format", "wav") + voice_obj = self._voice_obj[ModelType.VITS][task.get("id")][1] + task["id"] = self._voice_obj[ModelType.VITS][task.get("id")][0] + sampling_rate = voice_obj.sampling_rate + genertator = voice_obj.get_stream_audio(task, auto_break=True) + # audio = BytesIO() + for chunk in genertator: + encoded_audio = self.encode(sampling_rate, chunk, format) + for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio): + yield encoded_audio_chunk + def stream_vits_infer(self, state, fname=None): model = self.get_model(ModelType.VITS, state["id"]) state["id"] = self.get_real_id(ModelType.VITS, state["id"]) @@ -266,11 +278,11 @@ def stream_vits_infer(self, state, fname=None): if i < sentences_num - 1: audios.append(brk) - audio = np.concatenate(audios, axis=0) - encoded_audio = self.encode(sampling_rate, audio, state["format"]) + audio = np.concatenate(audios, axis=0) + encoded_audio = self.encode(sampling_rate, audio, state["format"]) - for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio): - yield encoded_audio_chunk + for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio): + yield encoded_audio_chunk # if getattr(config, "SAVE_AUDIO", False): # audio.write(encoded_audio.getvalue()) # if getattr(config, "SAVE_AUDIO", False): @@ -357,3 +369,28 @@ def bert_vits2_infer(self, state, encode=True): audio = np.concatenate(audios) return self.encode(sampling_rate, audio, state["format"]) if encode else audio + + def stream_bert_vits2_infer(self, state, fname=None): + model = self.get_model(ModelType.BERT_VITS2, state["id"]) + state["id"] = self.get_real_id(ModelType.BERT_VITS2, state["id"]) + + # 去除所有多余的空白字符 + if state["text"] is not None: + state["text"] = re.sub(r'\s+', ' ', state["text"]).strip() + sampling_rate = model.sampling_rate + + sentences_list = split_by_language(state["text"], state["speaker_lang"]) + + # audios = [] + + for (text, lang) in sentences_list: + sentences = sentence_split(text, state["max"]) + for sentence in sentences: + audio = model.infer(sentence, state["id"], lang, state["sdp_ratio"], state["noise"], + state["noise"], state["length"]) + # audios.append(audio) + # audio = np.concatenate(audios, axis=0) + encoded_audio = self.encode(sampling_rate, audio, state["format"]) + + for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio): + yield encoded_audio_chunk diff --git a/tts_app/templates/pages/index.html b/tts_app/templates/pages/index.html index 6722d5e..8138d4a 100644 --- a/tts_app/templates/pages/index.html +++ b/tts_app/templates/pages/index.html @@ -295,6 +295,11 @@

Your browser does not support the audio element. +
+ + +
@@ -392,6 +397,7 @@

} else if (model_type == 3) { var sdp_ratio = document.getElementById("input_sdp_ratio").value; var url = baseUrl + "/voice/bert-vits2?text=" + text + "&id=" + id; + var streaming = document.getElementById('streaming3'); } if (format != "") { url += "&format=" + format; @@ -411,7 +417,7 @@

if (max != "") { url += "&max=" + max; } - if (model_type == 1 && streaming.checked) { + if ((model_type == 1 || model_type == 3) && streaming.checked) { url += '&streaming=true'; } if (model_type == 3 && sdp_ratio != "") { diff --git a/tts_app/voice_api/views.py b/tts_app/voice_api/views.py index 2aa6991..bdc92d1 100644 --- a/tts_app/voice_api/views.py +++ b/tts_app/voice_api/views.py @@ -387,6 +387,7 @@ def voice_bert_vits2_api(): noisew = float(request_data.get("noisew", current_app.config.get("NOISEW", 0.8))) sdp_ratio = float(request_data.get("sdp_ratio", current_app.config.get("SDP_RATIO", 0.2))) max = int(request_data.get("max", current_app.config.get("MAX", 50))) + use_streaming = request_data.get('streaming', False, type=bool) except Exception as e: logger.error(f"[{ModelType.BERT_VITS2.value}] {e}") return make_response("parameter error", 400) @@ -418,9 +419,13 @@ def voice_bert_vits2_api(): if current_app.config.get("LANGUAGE_AUTOMATIC_DETECT", []) != []: speaker_lang = current_app.config.get("LANGUAGE_AUTOMATIC_DETECT") + if use_streaming and format.upper() != "MP3": + format = "mp3" + logger.warning("Streaming response only supports MP3 format.") + fname = f"{str(uuid.uuid1())}.{format}" file_type = f"audio/{format}" - task = {"text": text, + state = {"text": text, "id": id, "format": format, "length": length, @@ -430,11 +435,20 @@ def voice_bert_vits2_api(): "max": max, "lang": lang, "speaker_lang": speaker_lang} + + if use_streaming: + audio = tts_manager.stream_bert_vits2_infer(state) + response = make_response(audio) + response.headers['Content-Disposition'] = f'attachment; filename={fname}' + response.headers['Content-Type'] = file_type + return response + else: + t1 = time.time() + audio = tts_manager.bert_vits2_infer(state) + t2 = time.time() + logger.info(f"[{ModelType.BERT_VITS2.value}] finish in {(t2 - t1):.2f}s") - t1 = time.time() - audio = tts_manager.bert_vits2_infer(task) - t2 = time.time() - logger.info(f"[{ModelType.BERT_VITS2.value}] finish in {(t2 - t1):.2f}s") + if current_app.config.get("SAVE_AUDIO", False): logger.debug(f"[{ModelType.BERT_VITS2.value}] {fname}") diff --git a/utils/sentence.py b/utils/sentence.py index 7833d29..30773d0 100644 --- a/utils/sentence.py +++ b/utils/sentence.py @@ -93,8 +93,8 @@ def sentence_split(text: str, max: int) -> list: if p < len(text): sentences_list.append(text[p:]) - for i in sentences_list: - logging.debug(i) + # for i in sentences_list: + # logging.debug(i) return sentences_list