forked from facebookresearch/CPC_audio
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
12c407b
commit 1d636c6
Showing
3 changed files
with
284 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
#!/usr/bin/env python3 -u | ||
# !/usr/bin/env python3 -u | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import logging | ||
import os | ||
import sys | ||
import argparse | ||
from itertools import chain | ||
from pathlib import Path | ||
import time | ||
import copy | ||
import numpy as np | ||
import soundfile as sf | ||
|
||
from cpc.feature_loader import loadModel, FeatureModule | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
logging.basicConfig( | ||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | ||
datefmt="%Y-%m-%d %H:%M:%S", | ||
level=os.environ.get("LOGLEVEL", "INFO").upper(), | ||
stream=sys.stdout, | ||
) | ||
logger = logging.getLogger("zerospeech2021 abx") | ||
|
||
def parse_args(): | ||
# Run parameters | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("path_checkpoint", type=str, | ||
help="Path to the trained fairseq wav2vec2.0 model.") | ||
parser.add_argument("path_data", type=str, | ||
help="Path to the dataset that we want to compute ABX for.") | ||
parser.add_argument("path_output_dir", type=str, | ||
help="Path to the output directory.") | ||
parser.add_argument("--debug", action="store_true", | ||
help="Load only a very small amount of files for " | ||
"debugging purposes.") | ||
parser.add_argument("--cpu", action="store_true", | ||
help="Run on a cpu machine.") | ||
parser.add_argument("--file_extension", type=str, default="wav", | ||
help="Extension of the audio files in the dataset (default: wav).") | ||
parser.add_argument("--no_test", action="store_true", | ||
help="Don't compute embeddings for test-* parts of dataset") | ||
parser.add_argument('--gru_level', type=int, default=-1, | ||
help='Hidden level of the LSTM autoregressive model to be taken' | ||
'(default: -1, last layer).') | ||
parser.add_argument('--nullspace', action='store_true', | ||
help="Additionally load nullspace") | ||
return parser.parse_args() | ||
|
||
def main(): | ||
# Parse and print args | ||
args = parse_args() | ||
logger.info(args) | ||
|
||
# Load the model | ||
print("") | ||
print(f"Loading model from {args.path_checkpoint}") | ||
|
||
if args.gru_level is not None and args.gru_level > 0: | ||
updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level) | ||
else: | ||
updateConfig = None | ||
|
||
model = loadModel([args.path_checkpoint], load_nullspace=args.nullspace, updateConfig=updateConfig)[0] | ||
|
||
if args.gru_level is not None and args.gru_level > 0: | ||
# Keep hidden units at LSTM layers on sequential batches | ||
if args.nullspace: | ||
model.cpc.gAR.keepHidden = True | ||
else: | ||
model.gAR.keepHidden = True | ||
|
||
device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu" | ||
|
||
# Register the hooks | ||
layer_outputs = {} | ||
def get_layer_output(name): | ||
def hook(model, input, output): | ||
if type(output) is tuple: | ||
layer_outputs[name] = output[0].detach().squeeze(1).cpu().numpy() | ||
elif type(output) is dict: | ||
layer_outputs[name] = output["x"].detach().squeeze(0).cpu().numpy() | ||
else: | ||
layer_outputs[name] = output.detach().squeeze(0).cpu().numpy() | ||
return hook | ||
|
||
layer_names = [] | ||
layer_name = os.path.basename(os.path.dirname(args.path_checkpoint)) | ||
layer_names.append(layer_name) | ||
if not args.nullspace: | ||
model.gAR.register_forward_hook(get_layer_output(layer_name)) | ||
else: | ||
model.nullspace.register_forward_hook(get_layer_output(layer_name)) | ||
|
||
model = model.eval().to(device) | ||
print("Model loaded!") | ||
print(model) | ||
|
||
# Extract values from chosen layers and save them to files | ||
phonetic = "phonetic" | ||
datasets_path = os.path.join(args.path_data, phonetic) | ||
datasets = os.listdir(datasets_path) | ||
datasets = [dataset for dataset in datasets if not args.no_test or not dataset.startswith("test")] | ||
print(datasets) | ||
|
||
with torch.no_grad(): | ||
for dataset in datasets: | ||
print("> {}".format(dataset)) | ||
dataset_path = os.path.join(datasets_path, dataset) | ||
files = [f for f in os.listdir(dataset_path) if f.endswith(args.file_extension)] | ||
for i, f in enumerate(files): | ||
print("Progress {:2.1%}".format(i / len(files)), end="\r") | ||
input_f = os.path.join(dataset_path, f) | ||
x, sample_rate = sf.read(input_f) | ||
x = torch.tensor(x).float().reshape(1,1,-1).to(device) | ||
output = model(x, None)[0] | ||
|
||
for layer_name, value in layer_outputs.items(): | ||
output_dir = os.path.join(args.path_output_dir, layer_name, phonetic, dataset) | ||
Path(output_dir).mkdir(parents=True, exist_ok=True) | ||
out_f = os.path.join(output_dir, os.path.splitext(f)[0] + ".txt") | ||
np.savetxt(out_f, value) | ||
|
||
if __name__ == "__main__": | ||
#import ptvsd | ||
#ptvsd.enable_attach(('0.0.0.0', 7310)) | ||
#print("Attach debugger now") | ||
#ptvsd.wait_for_attach() | ||
main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
########## CHANGE THIS ################## | ||
ZEROSPEECH_EVAL_ENV=zerospeech2021 # Where the zerospeech2021-evaluate is installed | ||
CPC_ENV=202010-fairseq-c11 | ||
CONDA_PATH=/pio/scratch/2/i273233/miniconda3 | ||
######################################### | ||
|
||
DATASET_PATH=false | ||
ORIGINAL_DATASET_PATH=false | ||
CHECKPOINT_PATH=false | ||
OUTPUT_DIR=false | ||
NULLSPACE=false | ||
NO_TEST=false | ||
|
||
print_usage() { | ||
echo -e "Usage: ./eval_abx.sh" | ||
echo -e "\t-d DATASET_PATH" | ||
echo -e "\t-r ORIGINAL_DATASET_PATH" | ||
echo -e "\t-c CHECKPOINT_PATH" | ||
echo -e "\t-o OUTPUT_DIR" | ||
echo -e "OPTIONAL FLAGS:" | ||
echo -e "\t-n (Load a model with nullspace)" | ||
echo -e "\t-a CONDA_PATH" | ||
echo -e "\t-e CPC_ENV" | ||
echo -e "\t-z ZEROSPEECH_EVAL_ENV (The conda environment where the zerospeech2021-evaluate is installed)" | ||
echo -e "\t-t (Do not compute embeddings for test set)" | ||
} | ||
|
||
while getopts 'd:r:c:o:na:e:z:t' flag; do | ||
case "${flag}" in | ||
d) DATASET_PATH="${OPTARG}" ;; | ||
r) ORIGINAL_DATASET_PATH="${OPTARG}" ;; | ||
c) CHECKPOINT_PATH="${OPTARG}" ;; | ||
o) OUTPUT_DIR="${OPTARG}" ;; | ||
n) NULLSPACE=true ;; | ||
a) CONDA_PATH="${OPTARG}" ;; | ||
e) CPC_ENV="${OPTARG}" ;; | ||
z) ZEROSPEECH_EVAL_ENV="${OPTARG}" ;; | ||
t) NO_TEST=true ;; | ||
*) print_usage | ||
exit 1 ;; | ||
esac | ||
done | ||
|
||
echo $DATASET_PATH $ORIGINAL_DATASET_PATH $CHECKPOINT_PATH $OUTPUT_DIR $NULLSPACE $CONDA_PATH $CPC_ENV $ZEROSPEECH_EVAL_ENV $NO_TEST | ||
|
||
if [[ $DATASET_PATH == false || $ORIGINAL_DATASET_PATH == false || $CHECKPOINT_PATH == false || $OUTPUT_DIR == false ]] | ||
then | ||
echo "Either DATASET_PATH or ORIGINAL_DATASET_PATH or CHECKPOINT_PATH or OUTPUT_DIR is not set." | ||
print_usage | ||
exit 1 | ||
fi | ||
|
||
SCRIPT_PATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" | ||
|
||
results=$OUTPUT_DIR/results | ||
embeddings=$OUTPUT_DIR/embeddings | ||
mkdir -p embeddings | ||
|
||
source $CONDA_PATH/etc/profile.d/conda.sh | ||
SAVED_ENV=$(conda info | sed -n 's/\( \)*active environment : //p') | ||
echo SAVED_ENV: $SAVED_ENV | ||
|
||
ENV_TO_ACTIVATE=$CPC_ENV | ||
conda activate $ENV_TO_ACTIVATE | ||
|
||
params="" | ||
if [[ $NULLSPACE == true ]] | ||
then | ||
params="${params} --nullspace" | ||
fi | ||
|
||
if [[ $NO_TEST == true ]] | ||
then | ||
params="${params} --no_test" | ||
fi | ||
echo "Params: $params" | ||
|
||
echo "$SCRIPT_PATH/embeddings_abx.py" | ||
python $SCRIPT_PATH/embeddings_abx.py $CHECKPOINT_PATH $DATASET_PATH $embeddings --gru_level 2 $params | ||
|
||
directories=("dev-clean" "dev-other") | ||
if [[ $NO_TEST == false ]] | ||
then | ||
directories+=("test-clean" "test-other") | ||
fi | ||
echo "Directories: ${directories[@]}" | ||
|
||
for i in `basename -a $(ls -d $embeddings/*/)` | ||
do | ||
for directory in ${directories[@]} | ||
do | ||
for file in `ls $embeddings/$i/phonetic/$directory` | ||
do | ||
filename_no_ext="${file%.*}" | ||
if [[ ! -f "$ORIGINAL_DATASET_PATH/phonetic/$directory/${filename_no_ext}.wav" ]] | ||
then | ||
rm $embeddings/$i/phonetic/$directory/$file | ||
fi | ||
done | ||
done | ||
done | ||
|
||
conda activate $ZEROSPEECH_EVAL_ENV | ||
|
||
frame_shift="0.01" | ||
echo "Frame shift is ${frame_shift}s" | ||
|
||
metrics=("cosine" "euclidean") | ||
for metric in ${metrics[@]} | ||
do | ||
cat > $embeddings/$metric.yaml << EOF | ||
author: LSTM Baseline | ||
affiliation: EHESS, ENS, PSL Research Univerity, CNRS and Inria | ||
description: > | ||
CPC-big (trained on librispeech 960), kmeans (trained on librispeech 100), | ||
LSTM. See https://zerospeech.com/2021 for more details. | ||
open_source: true | ||
train_set: librispeech 100 and 960 | ||
gpu_budget: 60 | ||
parameters: | ||
phonetic: | ||
metric: ${metric} | ||
frame_shift: ${frame_shift} | ||
EOF | ||
|
||
for i in `basename -a $(ls -d $embeddings/*/)` | ||
do | ||
cp $embeddings/$metric.yaml $embeddings/$i/meta.yaml | ||
#zerospeech2021-evaluate -j 12 -o $results/$metric/$i --no-lexical --no-syntactic --no-semantic $DATASET_PATH $embeddings/$i | ||
#zerospeech2021-evaluate -j 12 -o $results/$metric/$i --force-cpu --no-lexical --no-syntactic --no-semantic $ORIGINAL_DATASET_PATH $embeddings/$i | ||
#zerospeech2021-evaluate -j 20 -o $results/$metric/$i --force-cpu --no-lexical --no-syntactic --no-semantic $ORIGINAL_DATASET_PATH $embeddings/$i | ||
zerospeech2021-evaluate -j 20 -o $results/$metric/$i --no-lexical --no-syntactic --no-semantic $ORIGINAL_DATASET_PATH $embeddings/$i | ||
done | ||
done | ||
|
||
for metric in ${metrics[@]} | ||
do | ||
for i in `basename -a $(ls -d $embeddings/*/)` | ||
do | ||
echo $i $metric | ||
cat $results/$metric/$i/score_phonetic.csv | ||
echo | ||
done | ||
done > $OUTPUT_DIR/combined_results.txt | ||
|
||
conda activate $SAVED_ENV |