Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

silero tts #12

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions programs/tts/silero_tts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Silero TTS

Text to speech service for Rhasspy based on [Silero TTS](https://github.com/snakers4/silero-models).

## Installation

1. Copy the contents of this directory to `config/programs/tts/silero_tts/`
2. Run `script/setup`
3. Download a model with `script/download`
* Example: `script/download --language ru --model v3_1_ru`
* Models are downloaded to `config/data/tts/silero_tts/models` directory
4. Test with `bin/tts_synthesize.py`
*
Example `script/run bin/tts_synthesize.py --tts-program silero_tts -f test.wav --debug 'test text!'`
53 changes: 53 additions & 0 deletions programs/tts/silero_tts/bin/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
import argparse
import logging
import tempfile
from pathlib import Path

import torch
from omegaconf import OmegaConf

_DIR = Path(__file__).parent
_LOGGER = logging.getLogger("setup")


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--language",
help="Voice language to download",
required=True
)
parser.add_argument(
"--model",
help="Model to download",
required=True
)
parser.add_argument(
"--destination", help="Path to destination directory (default: share)"
)

args = parser.parse_args()
logging.basicConfig(level=logging.INFO)

with tempfile.NamedTemporaryFile() as latest_silero_models:
torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
latest_silero_models.name,
progress=False)
models = OmegaConf.load(latest_silero_models.name)

if args.destination:
data_path = Path(args.destination)
else:
data_path = _DIR.parent.parent.parent.parent / "data" / "tts" / "silero_tts" / "models"

model_path = data_path / args.language

model_path.mkdir(parents=True, exist_ok=True)

torch.hub.download_url_to_file(models.tts_models[args.language][args.model].latest.package,
model_path / f'{args.model}.pt')


if __name__ == "__main__":
main()
32 changes: 32 additions & 0 deletions programs/tts/silero_tts/bin/list_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
import logging
import tempfile
from pathlib import Path

import torch
from omegaconf import OmegaConf

_DIR = Path(__file__).parent
_LOGGER = logging.getLogger("list_models")


def main() -> None:
"""Main method."""
with tempfile.NamedTemporaryFile() as latest_silero_models:
torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
latest_silero_models.name,
progress=False)
models = OmegaConf.load(latest_silero_models.name)

available_languages = list(models.tts_models.keys())
print(f'Available languages {available_languages}')

for lang in available_languages:
_models = list(models.tts_models.get(lang).keys())
print(f'Available models for {lang}: {_models}')


# -----------------------------------------------------------------------------

if __name__ == "__main__":
main()
127 changes: 127 additions & 0 deletions programs/tts/silero_tts/bin/silero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
import argparse
import logging
import sys
from pathlib import Path

import numpy as np
import torch

from rhasspy3.audio import AudioStart, DEFAULT_SAMPLES_PER_CHUNK, AudioChunk, AudioStop
from rhasspy3.event import write_event, read_event
from rhasspy3.tts import Synthesize

_FILE = Path(__file__)
_DIR = _FILE.parent
_LOGGER = logging.getLogger(_FILE.stem)


class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""

def __init__(self, logger, log_level=logging.INFO):
self.logger = logger
self.log_level = log_level
self.linebuf = ''

def write(self, buf):
for line in buf.rstrip().splitlines():
self.logger.log(self.log_level, line.rstrip())

def flush(self):
pass


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--language", required=True, help="Language to use")
parser.add_argument("--model", required=True, help="Model to use")
parser.add_argument("--sample_rate", help="Sample rate", default=48000)
parser.add_argument("--speaker", help="Voice to use", default='random')
parser.add_argument("--put_accent", help="Add accent", default=True)
parser.add_argument("--put_yo", help="Put Yo", default=True)
parser.add_argument("--samples-per-chunk", type=int, default=DEFAULT_SAMPLES_PER_CHUNK)
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
parser.add_argument("--destination", help="Path to destination directory")
parser.add_argument("--voice", help="Saved voice model")
args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

sys_stdout = sys.stdout
sys.stdout = StreamToLogger(_LOGGER, logging.INFO)

if args.destination:
data_path = Path(args.destination)
else:
data_path = _DIR.parent.parent.parent.parent / "data" / "tts" / "silero_tts"

model_path = data_path / "models" / args.language / f'{args.model}.pt'

model_params = {
'speaker': args.speaker,
'sample_rate': args.sample_rate,
'put_accent': args.put_accent,
'put_yo': args.put_yo
}

if args.voice:
voice_path = Path(args.voice)
if voice_path.is_absolute():
model_params['voice_path'] = voice_path
else:
model_params['voice_path'] = data_path / 'voices' / voice_path

device = torch.device('cpu')
model = torch.package.PackageImporter(model_path).load_pickle("tts_models", "model")
model.to(device)
# Listen for events
try:
while True:
event = read_event()
if event is None:
break

if Synthesize.is_type(event.type):
synthesize = Synthesize.from_event(event)
_LOGGER.debug("synthesize: text='%s'", synthesize.text)

audio = model.apply_tts(text=synthesize.text, **model_params)

width = 2
channels = 1
timestamp = 0
rate = args.sample_rate
bytes_per_chunk = args.samples_per_chunk * width

start_event = AudioStart(rate, width, channels, timestamp=timestamp)
write_event(start_event.event(), sys_stdout.buffer)
_LOGGER.debug(start_event)

# Audio
audio_bytes = (32767 * audio).numpy().astype(np.int16).tobytes()

while audio_bytes:
chunk = AudioChunk(
rate,
width,
channels,
audio_bytes[:bytes_per_chunk],
timestamp=timestamp,
)
write_event(chunk.event(), sys_stdout.buffer)
timestamp += chunk.milliseconds
audio_bytes = audio_bytes[bytes_per_chunk:]

write_event(AudioStop(timestamp=timestamp).event(), sys_stdout.buffer)

except KeyboardInterrupt:
pass


# -----------------------------------------------------------------------------

if __name__ == "__main__":
main()
129 changes: 129 additions & 0 deletions programs/tts/silero_tts/bin/silero_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import os
import socket
from pathlib import Path

import numpy as np
import torch

from rhasspy3.audio import AudioStart, DEFAULT_SAMPLES_PER_CHUNK, AudioChunk, AudioStop
from rhasspy3.event import write_event

_FILE = Path(__file__)
_DIR = _FILE.parent
_LOGGER = logging.getLogger(_FILE.stem)


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--language", required=True, help="Language to use")
parser.add_argument("--model", required=True, help="Model to use")
parser.add_argument("--sample_rate", help="Sample rate", default=48000)
parser.add_argument("--speaker", help="Voice to use", default='random')
parser.add_argument("--put_accent", help="Add accent", default=True)
parser.add_argument("--put_yo", help="Put Yo", default=True)
parser.add_argument("--socketfile", required=True, help="Path to Unix domain socket file")
parser.add_argument("--samples-per-chunk", type=int, default=DEFAULT_SAMPLES_PER_CHUNK)
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
parser.add_argument("--destination", help="Path to destination directory (default: share)")
parser.add_argument("--voice", help="Saved voice model")
args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

# Need to unlink socket if it exists
try:
os.unlink(args.socketfile)
except OSError:
pass

try:
if args.destination:
data_path = Path(args.destination)
else:
data_path = _DIR.parent.parent.parent.parent / "data" / "tts" / "silero_tts"

model_path = data_path / "models" / args.language / f'{args.model}.pt'

model_params = {
'speaker': args.speaker,
'sample_rate': args.sample_rate,
'put_accent': args.put_accent,
'put_yo': args.put_yo
}

if args.voice:
voice_path = Path(args.voice)
if voice_path.is_absolute():
model_params['voice_path'] = voice_path
else:
model_params['voice_path'] = data_path / 'voices' / voice_path

device = torch.device('cpu')
model = torch.package.PackageImporter(model_path).load_pickle("tts_models", "model")
model.to(device)
# Create socket server
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(args.socketfile)
sock.listen()

# Listen for connections
while True:
try:
connection, client_address = sock.accept()
_LOGGER.debug("Connection from %s", client_address)
with connection, connection.makefile(mode="rwb") as conn_file:
while True:
event_info = json.loads(conn_file.readline())
event_type = event_info["type"]

if event_type != "synthesize":
continue

raw_text = event_info["data"]["text"]
text = raw_text.strip()

_LOGGER.debug("synthesize: raw_text=%s, text='%s'", raw_text, text)

audio = model.apply_tts(text=text, **model_params)

width = 2
channels = 1
timestamp = 0
rate = args.sample_rate
bytes_per_chunk = args.samples_per_chunk * width

write_event(AudioStart(rate, width, channels, timestamp=timestamp).event(), conn_file)

# Audio
audio_bytes = (32767 * audio).numpy().astype(np.int16).tobytes()

while audio_bytes:
chunk = AudioChunk(
rate,
width,
channels,
audio_bytes[:bytes_per_chunk],
timestamp=timestamp,
)
write_event(chunk.event(), conn_file)
timestamp += chunk.milliseconds
audio_bytes = audio_bytes[bytes_per_chunk:]

write_event(AudioStop(timestamp=timestamp).event(), conn_file)
break
except KeyboardInterrupt:
break
except Exception:
_LOGGER.exception("Error communicating with socket client")
finally:
os.unlink(args.socketfile)


# -----------------------------------------------------------------------------

if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions programs/tts/silero_tts/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torchaudio
omegaconf
numpy
17 changes: 17 additions & 0 deletions programs/tts/silero_tts/script/download
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env bash
set -eo pipefail

# Directory of *this* script
this_dir="$( cd "$( dirname "$0" )" && pwd )"

# Base directory of repo
base_dir="$(realpath "${this_dir}/..")"

# Path to virtual environment
: "${venv:=${base_dir}/.venv}"

if [ -d "${venv}" ]; then
source "${venv}/bin/activate"
fi

python3 "${base_dir}/bin/download.py" "$@"
Loading