diff --git a/tensor2tensor/bin/t2t_datagen.py b/tensor2tensor/bin/t2t_datagen.py
index c974acdfd..3888a5bac 100644
--- a/tensor2tensor/bin/t2t_datagen.py
+++ b/tensor2tensor/bin/t2t_datagen.py
@@ -63,6 +63,8 @@
                     "Temporary storage directory.")
 flags.DEFINE_string("problem", "",
                     "The name of the problem to generate data for.")
+flags.DEFINE_string("language", "en",
+                    "Common Voice language code to generate data for.")
 flags.DEFINE_string("exclude_problems", "",
                     "Comma-separates list of problems to exclude.")
 flags.DEFINE_integer(
diff --git a/tensor2tensor/data_generators/common_voice.py b/tensor2tensor/data_generators/common_voice.py
index a0806659a..61a52ac59 100644
--- a/tensor2tensor/data_generators/common_voice.py
+++ b/tensor2tensor/data_generators/common_voice.py
@@ -32,39 +32,56 @@
 import tensorflow.compat.v1 as tf
 
-_COMMONVOICE_URL = "https://common-voice-data-download.s3.amazonaws.com/cv_corpus_v1.tar.gz"  # pylint: disable=line-too-long
-
-_COMMONVOICE_TRAIN_DATASETS = ["cv-valid-train", "cv-other-train"]
-_COMMONVOICE_DEV_DATASETS = ["cv-valid-dev", "cv-other-dev"]
-_COMMONVOICE_TEST_DATASETS = ["cv-valid-test", "cv-other-test"]
-
-
-def _collect_data(directory):
-  """Traverses directory collecting input and target files.
-
-  Args:
-    directory: base path to extracted audio and transcripts.
-  Returns:
-    list of (media_base, media_filepath, label) tuples
-  """
-  # Returns:
-  data_files = []
-  transcripts = [
-      filename for filename in os.listdir(directory)
-      if filename.endswith(".csv")
-  ]
-  for transcript in transcripts:
-    transcript_path = os.path.join(directory, transcript)
-    with open(transcript_path, "r") as transcript_file:
-      transcript_reader = csv.reader(transcript_file)
-      # skip header
-      _ = next(transcript_reader)
-      for transcript_line in transcript_reader:
-        media_name, label = transcript_line[0:2]
-        filename = os.path.join(directory, media_name)
-        data_files.append((media_name, filename, label))
-  return data_files
-
+_EXT_ARCHIVE = ".tar.gz"
+_CORPUS_VERSION = "cv-corpus-5-2020-06-22"
+_BASE_URL = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/" + _CORPUS_VERSION  # pylint: disable=line-too-long
+
+_COMMONVOICE_TRAIN_DATASETS = ["validated", "train"]
+_COMMONVOICE_DEV_DATASETS = ["dev", "other"]
+_COMMONVOICE_TEST_DATASETS = ["test"]
+
+_LANGUAGES = {
+    "tt",
+    "en",
+    "de",
+    "fr",
+    "cy",
+    "br",
+    "cv",
+    "tr",
+    "ky",
+    "ga-IE",
+    "kab",
+    "ca",
+    "zh-TW",
+    "sl",
+    "it",
+    "nl",
+    "cnh",
+    "eo",
+    "et",
+    "fa",
+    "pt",
+    "eu",
+    "es",
+    "zh-CN",
+    "mn",
+    "sah",
+    "dv",
+    "rw",
+    "sv-SE",
+    "ru",
+    "id",
+    "ar",
+    "ta",
+    "ia",
+    "lv",
+    "ja",
+    "vot",
+    "ab",
+    "zh-HK",
+    "rm-sursilv",
+}
 
 
 def _file_exists(path, filename):
   """Checks if the filename exists under the path."""
@@ -73,7 +90,7 @@ def _file_exists(path, filename):
 
 def _is_relative(path, filename):
-  """Checks if the filename is relative, not absolute."""
-  return os.path.abspath(os.path.join(path, filename)).startswith(path)
+  """Checks whether the joined path escapes the base path."""
+  return not os.path.abspath(os.path.join(path, filename)).startswith(path)
 
 
 @registry.register_problem()
@@ -81,9 +98,9 @@ class CommonVoice(speech_recognition.SpeechRecognitionProblem):
-  """Problem spec for Commonvoice using clean and noisy data."""
+  """Problem spec for Common Voice speech recognition."""
 
-  # Select only the clean data
-  TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS[:1]
+  # Train on the validated and train splits; evaluate on the dev split only.
+  TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS
   DEV_DATASETS = _COMMONVOICE_DEV_DATASETS[:1]
-  TEST_DATASETS = _COMMONVOICE_TEST_DATASETS[:1]
+  TEST_DATASETS = _COMMONVOICE_TEST_DATASETS
 
   @property
   def num_shards(self):
@@ -109,18 +126,18 @@ def use_train_shards_for_dev(self):
   def generator(self,
                 data_dir,
                 tmp_dir,
-                datasets,
-                eos_list=None,
-                start_from=0,
-                how_many=0):
-    del eos_list
-    i = 0
-
-    filename = os.path.basename(_COMMONVOICE_URL)
+                language,
+                datasets):
+    if language in _LANGUAGES:
+      lang_code = language
+    else:
+      lang_code = "en"  # fall back to English for unknown codes
+    url = _BASE_URL + "/" + lang_code + _EXT_ARCHIVE
+    filename = os.path.basename(url)
     compressed_file = generator_utils.maybe_download(tmp_dir, filename,
-                                                     _COMMONVOICE_URL)
+                                                     url)
 
-    read_type = "r:gz" if filename.endswith(".tgz") else "r"
+    read_type = "r:gz" if filename.endswith(".tar.gz") else "r"
     with tarfile.open(compressed_file, read_type) as corpus_tar:
       # Create a subset of files that don't already exist.
       #   tarfile.extractall errors when encountering an existing file
@@ -132,29 +149,33 @@ def generator(self,
       ]
       corpus_tar.extractall(tmp_dir, members=members)
 
-    raw_data_dir = os.path.join(tmp_dir, "cv_corpus_v1")
-    data_tuples = _collect_data(raw_data_dir)
+    raw_data_dir = os.path.join(tmp_dir, _CORPUS_VERSION, lang_code)
     encoders = self.feature_encoders(data_dir)
    audio_encoder = encoders["waveforms"]
     text_encoder = encoders["targets"]
+
     for dataset in datasets:
-      data_tuples = (tup for tup in data_tuples if tup[0].startswith(dataset))
-      for utt_id, media_file, text_data in tqdm.tqdm(
-          sorted(data_tuples)[start_from:]):
-        if how_many > 0 and i == how_many:
-          return
-        i += 1
-        wav_data = audio_encoder.encode(media_file)
-        yield {
-            "waveforms": wav_data,
-            "waveform_lens": [len(wav_data)],
-            "targets": text_encoder.encode(text_data),
-            "raw_transcript": [text_data],
-            "utt_id": [utt_id],
-            "spk_id": ["unknown"],
-        }
-
-  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+      dataset_path = os.path.join(raw_data_dir, dataset + ".tsv")
+      with tf.io.gfile.GFile(dataset_path) as file_:
+        num_rows = len(file_.readlines()) - 1  # row count for the progress bar
+        file_.seek(0)  # rewind; DictReader then consumes the header row
+        reader = csv.DictReader(file_, delimiter="\t")
+        for row in tqdm.tqdm(reader, total=num_rows):
+          file_path = os.path.join(raw_data_dir, "clips", row["path"])
+          if tf.io.gfile.exists(file_path):
+            try:
+              wav_data = audio_encoder.encode(file_path)
+              yield {
+                  "waveforms": wav_data,
+                  "waveform_lens": [len(wav_data)],
+                  "targets": text_encoder.encode(row["sentence"]),
+                  "raw_transcript": [row["sentence"]],
+                  "utt_id": [row["client_id"]],
+              }
+            except Exception as e:  # pylint: disable=broad-except
+              tf.logging.warning("Skipping %s: %s", file_path, e)
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1, language="en"):
     train_paths = self.training_filepaths(
         data_dir, self.num_shards, shuffled=False)
     dev_paths = self.dev_filepaths(
@@ -163,86 +184,17 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
         data_dir, self.num_test_shards, shuffled=True)
 
     generator_utils.generate_files(
-        self.generator(data_dir, tmp_dir, self.TEST_DATASETS), test_paths)
+        self.generator(data_dir, tmp_dir, language, self.TEST_DATASETS), test_paths)
 
     if self.use_train_shards_for_dev:
       all_paths = train_paths + dev_paths
       generator_utils.generate_files(
-          self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths)
+          self.generator(data_dir, tmp_dir, language, self.TRAIN_DATASETS), all_paths)
       generator_utils.shuffle_dataset(all_paths)
     else:
       generator_utils.generate_dataset_and_shuffle(
-          self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths,
-          self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)
-
-
-@registry.register_problem()
-class CommonVoiceTrainFullTestClean(CommonVoice):
-  """Problem to train on full set, but evaluate on clean data only."""
-
-  def training_filepaths(self, data_dir, num_shards, shuffled):
-    return CommonVoice.training_filepaths(self, data_dir, num_shards, shuffled)
-
-  def dev_filepaths(self, data_dir, num_shards, shuffled):
-    return CommonVoiceClean.dev_filepaths(self, data_dir, num_shards, shuffled)
-
-  def test_filepaths(self, data_dir, num_shards, shuffled):
-    return CommonVoiceClean.test_filepaths(self, data_dir, num_shards, shuffled)
-
-  def generate_data(self, data_dir, tmp_dir, task_id=-1):
-    raise Exception("Generate Commonvoice and Commonvoice_clean data.")
-
-  def filepattern(self, data_dir, mode, shard=None):
-    """Get filepattern for data files for mode.
-
-    Matches mode to a suffix.
-    * DatasetSplit.TRAIN: train
-    * DatasetSplit.EVAL: dev
-    * DatasetSplit.TEST: test
-    * tf.estimator.ModeKeys.PREDICT: dev
-
-    Args:
-      data_dir: str, data directory.
-      mode: DatasetSplit
-      shard: int, if provided, will only read data from the specified shard.
-
-    Returns:
-      filepattern str
-    """
-    shard_str = "-%05d" % shard if shard is not None else ""
-    if mode == problem.DatasetSplit.TRAIN:
-      path = os.path.join(data_dir, "common_voice")
-      suffix = "train"
-    elif mode in [problem.DatasetSplit.EVAL, tf.estimator.ModeKeys.PREDICT]:
-      path = os.path.join(data_dir, "common_voice_clean")
-      suffix = "dev"
-    else:
-      assert mode == problem.DatasetSplit.TEST
-      path = os.path.join(data_dir, "common_voice_clean")
-      suffix = "test"
-
-    return "%s-%s%s*" % (path, suffix, shard_str)
-
-
-@registry.register_problem()
-class CommonVoiceClean(CommonVoice):
-  """Problem spec for Common Voice using clean train and clean eval data."""
-
-  # Select only the "clean" data (crowdsourced quality control).
-  TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS[:1]
-  DEV_DATASETS = _COMMONVOICE_DEV_DATASETS[:1]
-  TEST_DATASETS = _COMMONVOICE_TEST_DATASETS[:1]
-
-
-@registry.register_problem()
-class CommonVoiceNoisy(CommonVoice):
-  """Problem spec for Common Voice using noisy train and noisy eval data."""
-
-  # Select only the "other" data.
-  TRAIN_DATASETS = _COMMONVOICE_TRAIN_DATASETS[1:]
-  DEV_DATASETS = _COMMONVOICE_DEV_DATASETS[1:]
-  TEST_DATASETS = _COMMONVOICE_TEST_DATASETS[1:]
-
+          self.generator(data_dir, tmp_dir, language, self.TRAIN_DATASETS), train_paths,
+          self.generator(data_dir, tmp_dir, language, self.DEV_DATASETS), dev_paths)
 
 def set_common_voice_length_hparams(hparams):
   hparams.max_length = 1650 * 80
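Usage sketch (not part of the patch): the diff adds a --language flag to
t2t_datagen.py but does not show the wiring that forwards the flag into
generate_data, so the snippet below drives the problem object directly.
The directory paths are placeholders; everything else follows the code as
patched above.

    # Hypothetical driver for per-language Common Voice data generation.
    # generate_data downloads cv-corpus-5-2020-06-22/<lang>.tar.gz into
    # tmp_dir, extracts it, and writes TFRecord shards into data_dir.
    # Unknown language codes fall back to "en" (see _LANGUAGES).
    from tensor2tensor.data_generators import common_voice

    problem = common_voice.CommonVoice()
    problem.generate_data(
        data_dir="/path/to/t2t_data",  # placeholder
        tmp_dir="/path/to/t2t_tmp",    # placeholder
        language="de")

Passing language by keyword (with task_id left at its default of -1) keeps
the call compatible with the base Problem.generate_data(data_dir, tmp_dir,
task_id) signature that t2t-datagen invokes.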