From fa1f7b55dcf061cb23abc94da388dc6ce58f266b Mon Sep 17 00:00:00 2001 From: Masao-Someki Date: Mon, 17 Jun 2024 00:01:46 +0900 Subject: [PATCH] Add Export/inference without frontend module --- espnet_onnx/asr/model/encoders/encoder.py | 12 ++++++-- espnet_onnx/export/asr/get_config.py | 4 +++ tests/integration_tests/test_asr.py | 35 +++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/espnet_onnx/asr/model/encoders/encoder.py b/espnet_onnx/asr/model/encoders/encoder.py index 7a94a17..bf79350 100644 --- a/espnet_onnx/asr/model/encoders/encoder.py +++ b/espnet_onnx/asr/model/encoders/encoder.py @@ -29,7 +29,11 @@ def __init__( self.config.model_path, providers=providers ) - self.frontend = Frontend(self.config.frontend, providers, use_quantized) + if self.config.frontend.frontend_type is None: + self.frontend = None + else: + self.frontend = Frontend(self.config.frontend, providers, use_quantized) + if self.config.do_normalize: if self.config.normalize.type == "gmvn": self.normalize = GlobalMVN(self.config.normalize) @@ -51,7 +55,11 @@ def __call__( speech_lengths: (Batch, ) """ # 1. Extract feature - feats, feat_length = self.frontend(speech, speech_length) + if self.frontend is not None: + feats, feat_length = self.frontend(speech, speech_length) + else: + feats = speech + feat_length = speech_length # 2. normalize with global MVN if self.config.do_normalize: diff --git a/espnet_onnx/export/asr/get_config.py b/espnet_onnx/export/asr/get_config.py index b7338f2..b0e26ab 100644 --- a/espnet_onnx/export/asr/get_config.py +++ b/espnet_onnx/export/asr/get_config.py @@ -90,6 +90,10 @@ def get_frontend_config(asr_frontend_model, frontend=None, **kwargs): frontend_config = frontend.get_model_config(**kwargs) elif isinstance(asr_frontend_model, DefaultFrontend): frontend_config = get_default_frontend(asr_frontend_model) + elif asr_frontend_model is None: + frontend_config = { + "frontend_type": None, + } else: raise ValueError("Currently only s3prl is supported.") diff --git a/tests/integration_tests/test_asr.py b/tests/integration_tests/test_asr.py index 9ff44e5..b1742b3 100644 --- a/tests/integration_tests/test_asr.py +++ b/tests/integration_tests/test_asr.py @@ -2,6 +2,7 @@ import librosa import pytest +import numpy as np from espnet_onnx import Speech2Text as onnxSpeech2Text @@ -21,6 +22,10 @@ "original/conformer" ] +asr_without_frontend = [ + "original/conformer" +] + @pytest.mark.parametrize("asr_config_names", asr_config_names) def test_asr(asr_config_names, load_config, wav_files, model_export): @@ -78,3 +83,33 @@ def test_asr_custom_dir(asr_config_names, load_config, wav_files, custom_dir_mod onnx_output = onnx_model(y)[0] assert espnet_output[2] == onnx_output[2] + + +@pytest.mark.parametrize("asr_config_names", asr_without_frontend) +def test_asr_without_frontend(asr_config_names, load_config, custom_dir_model_export): + config = load_config(asr_config_names, model_type="integration") + config.tag_name = "test/integration/" + config.tag_name + + # build ASR model + espnet_model = build_model(config.model_config) + espnet_model.asr_model.frontend = None + + # test export + export_model(custom_dir_model_export, copy.deepcopy(espnet_model), config) + file_paths = check_models( + custom_dir_model_export.cache_dir, + config.tag_name, + config.check_export, + False, + ) + check_optimize(config, file_paths) + eval_model(espnet_model) + + # parity check with espnet model + onnx_model = onnxSpeech2Text(config.tag_name, cache_dir=custom_dir_model_export.cache_dir) + for feat_length in [100, 200]: + y = np.random.rand(feat_length, 80) + espnet_output = espnet_model(y)[0] + onnx_output = onnx_model(y)[0] + + assert espnet_output[2] == onnx_output[2]