This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Fix CommonVoice dataset for Speech Recognition problem #1852

Open · wants to merge 4 commits into base: master
2 changes: 2 additions & 0 deletions tensor2tensor/bin/t2t_datagen.py
@@ -63,6 +63,8 @@
"Temporary storage directory.")
flags.DEFINE_string("problem", "",
"The name of the problem to generate data for.")
flags.DEFINE_string("language", "en",
"Common Voice language code.")
flags.DEFINE_string("exclude_problems", "",
"Comma-separates list of problems to exclude.")
flags.DEFINE_integer(
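The hunk above only defines the new --language flag. A minimal sketch of how the datagen driver could forward it to the problem is shown below; this wiring is not part of the diff, and the helper function is hypothetical:

# Hypothetical wiring, not shown in this PR: forward the new --language flag
# from t2t_datagen.py to CommonVoice.generate_data(), which now expects it.
import tensorflow.compat.v1 as tf
from tensor2tensor.utils import registry

FLAGS = tf.flags.FLAGS

def generate_for_problem(problem_name):
  problem = registry.problem(problem_name)
  # CommonVoice.generate_data(data_dir, tmp_dir, language, task_id=-1)
  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir, FLAGS.language)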
228 changes: 90 additions & 138 deletions tensor2tensor/data_generators/common_voice.py
@@ -32,39 +32,56 @@

import tensorflow.compat.v1 as tf

_COMMONVOICE_URL = "https://common-voice-data-download.s3.amazonaws.com/cv_corpus_v1.tar.gz" # pylint: disable=line-too-long

_COMMONVOICE_TRAIN_DATASETS = ["cv-valid-train", "cv-other-train"]
_COMMONVOICE_DEV_DATASETS = ["cv-valid-dev", "cv-other-dev"]
_COMMONVOICE_TEST_DATASETS = ["cv-valid-test", "cv-other-test"]


def _collect_data(directory):
"""Traverses directory collecting input and target files.

Args:
directory: base path to extracted audio and transcripts.
Returns:
list of (media_base, media_filepath, label) tuples
"""
# Returns:
data_files = []
transcripts = [
filename for filename in os.listdir(directory)
if filename.endswith(".csv")
]
for transcript in transcripts:
transcript_path = os.path.join(directory, transcript)
with open(transcript_path, "r") as transcript_file:
transcript_reader = csv.reader(transcript_file)
# skip header
_ = next(transcript_reader)
for transcript_line in transcript_reader:
media_name, label = transcript_line[0:2]
filename = os.path.join(directory, media_name)
data_files.append((media_name, filename, label))
return data_files

_EXT_ARCHIVE = ".tar.gz"
_CORPUS_VERSION = "cv-corpus-5-2020-06-22"
_BASE_URL = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/" + _CORPUS_VERSION # pylint: disable=line-too-long

_COMMONVOICE_TRAIN_DATASETS = ["validated", "train"]
_COMMONVOICE_DEV_DATASETS = ["dev", "other"]
_COMMONVOICE_TEST_DATASETS = ["test"]

_LANGUAGES = {
"tt",
"en",
"de",
"fr",
"cy",
"br",
"cv",
"tr",
"ky",
"ga-IE",
"kab",
"ca",
"zh-TW",
"sl",
"it",
"nl",
"cnh",
"eo",
"et",
"fa",
"pt",
"eu",
"es",
"zh-CN",
"mn",
"sah",
"dv",
"rw",
"sv-SE",
"ru",
"id",
"ar",
"ta",
"ia",
"lv",
"ja",
"vot",
"ab",
"zh-HK",
"rm-sursilv"
}

def _file_exists(path, filename):
"""Checks if the filename exists under the path."""
@@ -73,17 +90,17 @@ def _file_exists(path, filename):

def _is_relative(path, filename):
"""Checks if the filename is relative, not absolute."""
return os.path.abspath(os.path.join(path, filename)).startswith(path)
return not os.path.abspath(os.path.join(path, filename)).startswith(path)


@registry.register_problem()
class CommonVoice(speech_recognition.SpeechRecognitionProblem):
"""Problem spec for Commonvoice using clean and noisy data."""

# Select only the clean data
TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS[:1]
TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS
DEV_DATASETS = _COMMONVOICE_DEV_DATASETS[:1]
TEST_DATASETS = _COMMONVOICE_TEST_DATASETS[:1]
TEST_DATASETS = _COMMONVOICE_TEST_DATASETS

@property
def num_shards(self):
@@ -109,18 +126,18 @@ def use_train_shards_for_dev(self):
def generator(self,
data_dir,
tmp_dir,
datasets,
eos_list=None,
start_from=0,
how_many=0):
del eos_list
i = 0

filename = os.path.basename(_COMMONVOICE_URL)
language,
datasets):
if language in _LANGUAGES:
_CODE = language
else:
_CODE = "en"
_URL = _BASE_URL + _CODE + _EXT_ARCHIVE
filename = os.path.basename(_URL)
compressed_file = generator_utils.maybe_download(tmp_dir, filename,
_COMMONVOICE_URL)
_URL)

read_type = "r:gz" if filename.endswith(".tgz") else "r"
read_type = "r:gz" if filename.endswith(".tar.gz") else "r"
with tarfile.open(compressed_file, read_type) as corpus_tar:
# Create a subset of files that don't already exist.
# tarfile.extractall errors when encountering an existing file
@@ -132,29 +149,33 @@ def generator(self,
]
corpus_tar.extractall(tmp_dir, members=members)

raw_data_dir = os.path.join(tmp_dir, "cv_corpus_v1")
data_tuples = _collect_data(raw_data_dir)
raw_data_dir = os.path.join(tmp_dir, _CORPUS_VERSION + _CODE + "/")
encoders = self.feature_encoders(data_dir)
audio_encoder = encoders["waveforms"]
text_encoder = encoders["targets"]

for dataset in datasets:
data_tuples = (tup for tup in data_tuples if tup[0].startswith(dataset))
for utt_id, media_file, text_data in tqdm.tqdm(
sorted(data_tuples)[start_from:]):
if how_many > 0 and i == how_many:
return
i += 1
wav_data = audio_encoder.encode(media_file)
yield {
"waveforms": wav_data,
"waveform_lens": [len(wav_data)],
"targets": text_encoder.encode(text_data),
"raw_transcript": [text_data],
"utt_id": [utt_id],
"spk_id": ["unknown"],
}

def generate_data(self, data_dir, tmp_dir, task_id=-1):
full_dataset_filename = dataset + ".tsv"
with tf.io.gfile.GFile(os.path.join(raw_data_dir, full_dataset_filename)) as file_:
dataset = csv.DictReader(file_, delimiter="\t")
f_length = len(file_.readlines())
file_.seek(0, 0)
for i, row in tqdm.tqdm(enumerate(dataset), total=f_length):
file_path = os.path.join(os.path.join(raw_data_dir, "clips/"), row["path"])
if tf.io.gfile.exists(file_path):
try:
wav_data = audio_encoder.encode(file_path)
yield {
"waveforms": wav_data,
"waveform_lens": [len(wav_data)],
"targets": text_encoder.encode(row["sentence"]),
"raw_transcript": [row["sentence"]],
"utt_id": row["client_id"]
}
except Exception as e:
print(e)

def generate_data(self, data_dir, tmp_dir, language, task_id=-1):
train_paths = self.training_filepaths(
data_dir, self.num_shards, shuffled=False)
dev_paths = self.dev_filepaths(
@@ -163,86 +184,17 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
data_dir, self.num_test_shards, shuffled=True)

generator_utils.generate_files(
self.generator(data_dir, tmp_dir, self.TEST_DATASETS), test_paths)
self.generator(data_dir, tmp_dir, language, self.TEST_DATASETS), test_paths)

if self.use_train_shards_for_dev:
all_paths = train_paths + dev_paths
generator_utils.generate_files(
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths)
self.generator(data_dir, tmp_dir, language, self.TRAIN_DATASETS), all_paths)
generator_utils.shuffle_dataset(all_paths)
else:
generator_utils.generate_dataset_and_shuffle(
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths,
self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)


@registry.register_problem()
class CommonVoiceTrainFullTestClean(CommonVoice):
"""Problem to train on full set, but evaluate on clean data only."""

def training_filepaths(self, data_dir, num_shards, shuffled):
return CommonVoice.training_filepaths(self, data_dir, num_shards, shuffled)

def dev_filepaths(self, data_dir, num_shards, shuffled):
return CommonVoiceClean.dev_filepaths(self, data_dir, num_shards, shuffled)

def test_filepaths(self, data_dir, num_shards, shuffled):
return CommonVoiceClean.test_filepaths(self, data_dir, num_shards, shuffled)

def generate_data(self, data_dir, tmp_dir, task_id=-1):
raise Exception("Generate Commonvoice and Commonvoice_clean data.")

def filepattern(self, data_dir, mode, shard=None):
"""Get filepattern for data files for mode.

Matches mode to a suffix.
* DatasetSplit.TRAIN: train
* DatasetSplit.EVAL: dev
* DatasetSplit.TEST: test
* tf.estimator.ModeKeys.PREDICT: dev

Args:
data_dir: str, data directory.
mode: DatasetSplit
shard: int, if provided, will only read data from the specified shard.

Returns:
filepattern str
"""
shard_str = "-%05d" % shard if shard is not None else ""
if mode == problem.DatasetSplit.TRAIN:
path = os.path.join(data_dir, "common_voice")
suffix = "train"
elif mode in [problem.DatasetSplit.EVAL, tf.estimator.ModeKeys.PREDICT]:
path = os.path.join(data_dir, "common_voice_clean")
suffix = "dev"
else:
assert mode == problem.DatasetSplit.TEST
path = os.path.join(data_dir, "common_voice_clean")
suffix = "test"

return "%s-%s%s*" % (path, suffix, shard_str)


@registry.register_problem()
class CommonVoiceClean(CommonVoice):
"""Problem spec for Common Voice using clean train and clean eval data."""

# Select only the "clean" data (crowdsourced quality control).
TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS[:1]
DEV_DATASETS = _COMMONVOICE_DEV_DATASETS[:1]
TEST_DATASETS = _COMMONVOICE_TEST_DATASETS[:1]


@registry.register_problem()
class CommonVoiceNoisy(CommonVoice):
"""Problem spec for Common Voice using noisy train and noisy eval data."""

# Select only the "other" data.
TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS[1:]
DEV_DATASETS = _COMMONVOICE_DEV_DATASETS[1:]
TEST_DATASETS = _COMMONVOICE_TEST_DATASETS[1:]

self.generator(data_dir, tmp_dir, language, self.TRAIN_DATASETS), train_paths,
self.generator(data_dir, tmp_dir, language, self.DEV_DATASETS), dev_paths)

def set_common_voice_length_hparams(hparams):
hparams.max_length = 1650 * 80
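For context, here is a standalone sketch of the per-language download URL and TSV layout that the rewritten generator relies on. The exact URL separator and the extracted directory structure of the cv-corpus-5-2020-06-22 bundles are assumptions in this sketch, not something the diff itself verifies:

# Sketch only: assumed Common Voice v5 bundle layout used by the new generator.
import csv
import os

_CORPUS_VERSION = "cv-corpus-5-2020-06-22"
_BASE_URL = ("https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4"
             ".s3.amazonaws.com/" + _CORPUS_VERSION)

def dataset_url(language):
  # Assumed form: <base>/<language>.tar.gz, e.g. .../cv-corpus-5-2020-06-22/en.tar.gz
  return "%s/%s.tar.gz" % (_BASE_URL, language)

def iter_clips(tmp_dir, language, split="train"):
  """Yields (clip_path, sentence, client_id) rows from one TSV split."""
  # Assumed extraction path: <tmp_dir>/cv-corpus-5-2020-06-22/<language>/
  lang_dir = os.path.join(tmp_dir, _CORPUS_VERSION, language)
  with open(os.path.join(lang_dir, split + ".tsv")) as tsv_file:
    for row in csv.DictReader(tsv_file, delimiter="\t"):
      yield (os.path.join(lang_dir, "clips", row["path"]),
             row["sentence"], row["client_id"])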