Merge pull request #104 from mchinen/v1.3.0
Update models
yeroro authored Nov 10, 2022
2 parents a00eade + f079e8c commit d48d48c
Showing 26 changed files with 99 additions and 86 deletions.
19 changes: 10 additions & 9 deletions README.md
@@ -154,15 +154,15 @@ If you press 'Benchmark', you should see something like the following in logcat
on a Pixel 6 Pro when running the benchmark:

```shell
- lyra_benchmark: feature_extractor: max: 0.685 ms min: 0.206 ms mean: 0.219 ms stdev: 0.000 ms
- lyra_benchmark: quantizer_quantize: max: 0.250 ms min: 0.076 ms mean: 0.082 ms stdev: 0.000 ms
- lyra_benchmark: quantizer_decode: max: 0.152 ms min: 0.027 ms mean: 0.030 ms stdev: 0.001 ms
- lyra_benchmark: model_decode: max: 0.560 ms min: 0.223 ms mean: 0.237 ms stdev: 0.000 ms
- lyra_benchmark: total: max: 1.560 ms min: 0.541 ms mean: 0.569 ms stdev: 0.005 ms
+ lyra_benchmark: feature_extractor: max: 0.575 ms min: 0.131 ms mean: 0.139 ms stdev: 0.004 ms
+ lyra_benchmark: quantizer_quantize: max: 0.304 ms min: 0.105 ms mean: 0.109 ms stdev: 0.002 ms
+ lyra_benchmark: quantizer_decode: max: 0.103 ms min: 0.025 ms mean: 0.026 ms stdev: 0.000 ms
+ lyra_benchmark: model_decode: max: 0.462 ms min: 0.187 ms mean: 0.197 ms stdev: 0.001 ms
+ lyra_benchmark: total: max: 1.160 ms min: 0.452 ms mean: 0.473 ms stdev: 0.009 ms
```

This shows that decoding a 50Hz frame (each frame is 20 milliseconds) takes
- 0.569 milliseconds on average. So decoding is performed at around 35 (20/0.569)
+ 0.473 milliseconds on average. So decoding is performed at around 42 (20/0.473)
times faster than realtime.
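
The speed-up figure in this README passage is the frame duration divided by the mean per-frame processing time. The following is a minimal sketch of that arithmetic using the two `total` means that appear in this diff. `RealtimeFactor` and the constant names are hypothetical and are not part of Lyra.

```cpp
// Hypothetical helper (not part of Lyra): how many times faster than realtime
// one frame is processed, given the frame duration and the mean processing
// time, both in milliseconds.
constexpr double RealtimeFactor(double frame_ms, double mean_processing_ms) {
  return frame_ms / mean_processing_ms;
}

// A 50 Hz frame rate means each frame covers 1000 / 50 = 20 ms of audio.
constexpr double kFrameMs = 1000.0 / 50.0;

// Mean `total` time before and after this commit, taken from the logcat lines.
constexpr double kOldFactor = RealtimeFactor(kFrameMs, 0.569);  // about 35.1
constexpr double kNewFactor = RealtimeFactor(kFrameMs, 0.473);  // about 42.3

static_assert(kOldFactor > 35.0 && kOldFactor < 35.5, "old figure rounds to 35");
static_assert(kNewFactor > 42.0 && kNewFactor < 42.5, "new figure rounds to 42");
```

The `static_assert` lines double as a compile-time check that the rounded figures quoted in the README (35 and 42) follow from the measured means.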

To build your own android app, you can either use the cc_library target outputs
@@ -262,9 +262,10 @@ class LyraEncoder : public LyraEncoderInterface {
The static `Create` method instantiates a `LyraEncoder` with the desired sample
rate in Hertz, number of channels and bitrate, as long as those parameters are
- supported. Else it returns a nullptr. The `Create` method also needs to know if
- DTX should be enabled and where the model weights are stored. It also checks
- that these weights exist and are compatible with the current Lyra version.
+ supported (see `lyra_encoder.h` for supported parameters). Otherwise it returns
+ a nullptr. The `Create` method also needs to know if DTX should be enabled and
+ where the model weights are stored. It also checks that these weights exist and
+ are compatible with the current Lyra version.
Given a `LyraEncoder`, any audio stream can be compressed using the `Encode`
method. The provided span of int16-formatted samples is assumed to contain 20ms
10 changes: 3 additions & 7 deletions WORKSPACE
@@ -39,7 +39,7 @@ git_repository(
tag = "20211102.0",
# Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved.
patches = [
- "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff"
+ "@//patches:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff"
],
patch_args = [
"-p1",
@@ -174,12 +174,8 @@ git_repository(
)

# Check bazel version requirement, which is stricter than TensorFlow's.
- load(
- "@org_tensorflow//tensorflow:version_check.bzl",
- "check_bazel_version_at_least",
- )
-
- check_bazel_version_at_least("3.7.2")
+ load("@bazel_skylib//lib:versions.bzl", "versions")
+ versions.check("3.7.2")

# TF WORKSPACE Loading functions
# This section uses a subset of the tensorflow WORKSPACE loading by reusing its contents.
2 changes: 2 additions & 0 deletions android_example/BUILD
@@ -1,5 +1,7 @@
load("@rules_android//android:rules.bzl", "android_binary", "android_library")

+ # Placeholder for jni import
+
package(default_visibility = ["//visibility:public"])

licenses(["notice"])
6 changes: 4 additions & 2 deletions decoder_main.cc
@@ -52,8 +52,10 @@ ABSL_FLAG(chromemedia::codec::PacketLossPattern, fixed_packet_loss_pattern,
"|packet_loss_rate| and |average_burst_length|.");
ABSL_FLAG(std::string, model_path, "model_coeffs",
"Path to directory containing TFLite files. For mobile this is the "
- "absolute path, like '/sdcard/model_coeffs/'. For desktop this is "
- "the path relative to the binary.");
+ "absolute path, like "
+ "'/data/local/tmp/model_coeffs/'. For "
+ "desktop "
+ "this is the path relative to the binary.");

int main(int argc, char** argv) {
absl::SetProgramUsageMessage(argv[0]);
9 changes: 6 additions & 3 deletions encoder_main.cc
@@ -32,7 +32,8 @@ ABSL_FLAG(std::string, output_dir, "",
"name as the wav file they come from with a '.lyra' postfix. Will "
"overwrite existing files.");
ABSL_FLAG(int, bitrate, 3200,
- "The bitrate in bps with which to quantize the file.");
+ "The bitrate in bps with which to quantize the file. The "
+ "bitrate options can be seen in lyra_encoder.h");
ABSL_FLAG(bool, enable_preprocessing, false,
"If enabled runs the input signal through the preprocessing "
"module before encoding.");
@@ -41,8 +42,10 @@ ABSL_FLAG(bool, enable_dtx, false,
"when noise is detected.");
ABSL_FLAG(std::string, model_path, "model_coeffs",
"Path to directory containing TFLite files. For mobile this is the "
- "absolute path, like '/sdcard/model_coeffs/'. For desktop this is "
- "the path relative to the binary.");
+ "absolute path, like "
+ "'/data/local/tmp/model_coeffs/'. For "
+ "desktop "
+ "this is the path relative to the binary.");

int main(int argc, char** argv) {
absl::SetProgramUsageMessage(argv[0]);
6 changes: 4 additions & 2 deletions lyra_benchmark.cc
@@ -27,8 +27,10 @@ ABSL_FLAG(int, num_cond_vectors, 2000,

ABSL_FLAG(std::string, model_path, "model_coeffs",
"Path to directory containing TFLite files. For mobile this is the "
- "absolute path, like '/sdcard/model_coeffs/'. For desktop this is "
- "the path relative to the binary.");
+ "absolute path, like "
+ "'/data/local/tmp/model_coeffs/'. For "
+ "desktop "
+ "this is the path relative to the binary.");

ABSL_FLAG(bool, benchmark_feature_extraction, true,
"Whether to benchmark the feature extraction.");
16 changes: 6 additions & 10 deletions lyra_benchmark_lib.cc
@@ -159,8 +159,8 @@ std::optional<std::vector<int16_t>> MaybeRunGenerativeModel(
return decoded;
}

- // Prints stats and writes CSV for the runtime information in |timings| to file
- // under /{sdcard,tmp}/benchmark/lyra/.
+ // Prints stats for the runtime information in |timings|. For desktop, also
+ // writes to CSV files under /tmp/benchmark/.
void PrintStatsAndWriteCSV(const std::vector<int64_t>& timings,
const absl::string_view title) {
constexpr absl::string_view stats_template =
@@ -210,19 +210,15 @@ int lyra_benchmark(const int num_cond_vectors,
const std::string model_path = GetCompleteArchitecturePath(model_base_path);

std::unique_ptr<FeatureExtractorInterface> feature_extractor =
- benchmark_feature_extraction
- ? CreateFeatureExtractor(
- kInternalSampleRateHz, kNumFeatures, num_samples_per_hop,
- GetNumSamplesPerWindow(kInternalSampleRateHz), model_path)
- : nullptr;
+ benchmark_feature_extraction ? CreateFeatureExtractor(model_path)
+ : nullptr;

std::unique_ptr<VectorQuantizerInterface> vector_quantizer =
- benchmark_quantizer ? CreateQuantizer(kNumFeatures, model_path) : nullptr;
+ benchmark_quantizer ? CreateQuantizer(model_path) : nullptr;

std::unique_ptr<GenerativeModelInterface> model =
benchmark_generative_model
- ? CreateGenerativeModel(GetNumSamplesPerHop(kInternalSampleRateHz),
- kNumFeatures, model_path)
+ ? CreateGenerativeModel(kNumFeatures, model_path)
: nullptr;

std::vector<int64_t> feature_extractor_timings;
8 changes: 3 additions & 5 deletions lyra_components.cc
@@ -40,19 +40,17 @@ constexpr int kMaxNumPacketBits = 184;
} // namespace

std::unique_ptr<VectorQuantizerInterface> CreateQuantizer(
- int num_output_features, const ghc::filesystem::path& model_path) {
+ const ghc::filesystem::path& model_path) {
return ResidualVectorQuantizer::Create(model_path);
}

std::unique_ptr<GenerativeModelInterface> CreateGenerativeModel(
- int num_samples_per_hop, int num_output_features,
- const ghc::filesystem::path& model_path) {
+ int num_output_features, const ghc::filesystem::path& model_path) {
return LyraGanModel::Create(model_path, num_output_features);
}

std::unique_ptr<FeatureExtractorInterface> CreateFeatureExtractor(
- int sample_rate_hz, int num_features, int num_samples_per_hop,
- int num_samples_per_window, const ghc::filesystem::path& model_path) {
+ const ghc::filesystem::path& model_path) {
return SoundStreamEncoder::Create(model_path);
}

8 changes: 3 additions & 5 deletions lyra_components.h
@@ -30,15 +30,13 @@ namespace chromemedia {
namespace codec {

std::unique_ptr<VectorQuantizerInterface> CreateQuantizer(
- int num_output_features, const ghc::filesystem::path& model_path);
+ const ghc::filesystem::path& model_path);

std::unique_ptr<GenerativeModelInterface> CreateGenerativeModel(
- int num_samples_per_hop, int num_output_features,
- const ghc::filesystem::path& model_path);
+ int num_output_features, const ghc::filesystem::path& model_path);

std::unique_ptr<FeatureExtractorInterface> CreateFeatureExtractor(
- int sample_rate_hz, int num_features, int num_samples_per_hop,
- int num_samples_per_window, const ghc::filesystem::path& model_path);
+ const ghc::filesystem::path& model_path);

std::unique_ptr<PacketInterface> CreatePacket(int num_header_bits,
int num_quantized_bits);
9 changes: 5 additions & 4 deletions lyra_config.cc
@@ -24,30 +24,31 @@ namespace codec {
// The Lyra version is |kVersionMajor|.|kVersionMinor|.|kVersionMicro|
// The version is not used internally, but clients may use it to configure
// behavior, such as checking for version bumps that break the bitstream.
- // The major version should be bumped whenever the bitstream breaks.
+ // The major version should be bumped for major architectural changes.
const int kVersionMajor = 1;
- // |kVersionMinor| needs to be increased every time a new version requires a
+ // The minor version needs to be increased every time a new version requires a
// simultaneous change in code and weights or if the bit stream is modified. The
// |identifier| field needs to be set in lyra_config.textproto to match this.
- const int kVersionMinor = 2;
+ const int kVersionMinor = 3;
// The micro version is for other things like a release of bugfixes.
const int kVersionMicro = 0;
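
The comments in this hunk say that clients may use the version to detect bitstream breaks, and that after this commit the minor version is the one bumped when the bit stream changes. Below is a sketch of a client-side check built on that reading. `IsBitstreamCompatible` is a hypothetical helper, not a Lyra function, and treating any major or minor mismatch as incompatible is a conservative assumption rather than documented behavior.

```cpp
// Values copied from lyra_config.cc after this commit.
constexpr int kVersionMajor = 1;
constexpr int kVersionMinor = 3;

// Hypothetical client-side check (not part of Lyra). Assumes a peer is only
// safe to talk to when its major and minor versions both match ours, because
// a minor bump may modify the bit stream. The micro version is ignored since
// it is described as covering bugfix releases.
constexpr bool IsBitstreamCompatible(int peer_major, int peer_minor) {
  return peer_major == kVersionMajor && peer_minor == kVersionMinor;
}

static_assert(IsBitstreamCompatible(1, 3), "same major.minor is compatible");
static_assert(!IsBitstreamCompatible(1, 2), "older minor may use another bit stream");
```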

const int kNumFeatures = 64;
const int kNumMelBins = 160;
const int kNumChannels = 1;
- const int kFrameRate = 50;
const int kOverlapFactor = 2;

// LINT.IfChange
const int kNumHeaderBits = 0;
+ const int kFrameRate = 50;
const std::vector<int>& GetSupportedQuantizedBits() {
static const std::vector<int>* const supported_quantization_bits =
new std::vector<int>{64, 120, 184};
return *supported_quantization_bits;
}
// LINT.ThenChange(
// lyra_components.cc,
+ // lyra_encoder.h,
// residual_vector_quantizer.h,
// )
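
The supported quantized-bit sizes in this hunk relate to bitrates through the frame rate: each frame carries `kNumHeaderBits` plus the quantized bits, and `kFrameRate` frames are produced per second. The following sketch works through that arithmetic. `BitrateFromQuantizedBits` is a hypothetical helper written for illustration, and it assumes exactly one packet per 20 ms frame.

```cpp
// Values copied from lyra_config.cc in this diff.
constexpr int kNumHeaderBits = 0;
constexpr int kFrameRate = 50;

// Hypothetical helper (not part of Lyra): bits per second produced when every
// frame is sent as one packet with the given number of quantized bits.
constexpr int BitrateFromQuantizedBits(int num_quantized_bits) {
  return (kNumHeaderBits + num_quantized_bits) * kFrameRate;
}

// The three sizes returned by GetSupportedQuantizedBits().
static_assert(BitrateFromQuantizedBits(64) == 3200, "64 bits per frame");
static_assert(BitrateFromQuantizedBits(120) == 6000, "120 bits per frame");
static_assert(BitrateFromQuantizedBits(184) == 9200, "184 bits per frame");
```

The smallest result, 3200 bps, matches the default value of the `bitrate` flag in `encoder_main.cc`.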

5 changes: 2 additions & 3 deletions lyra_decoder.cc
@@ -114,8 +114,7 @@ std::unique_ptr<LyraDecoder> LyraDecoder::Create(
return nullptr;
}
// All internal components operate at |kInternalSampleRateHz|.
- auto model =
- CreateGenerativeModel(kNumSamplesPerHop, kNumFeatures, model_path);
+ auto model = CreateGenerativeModel(kNumFeatures, model_path);
if (model == nullptr) {
LOG(ERROR) << "New model could not be instantiated.";
return nullptr;
@@ -134,7 +133,7 @@ std::unique_ptr<LyraDecoder> LyraDecoder::Create(
LOG(ERROR) << "Could not create Noise Estimator.";
return nullptr;
}
- auto vector_quantizer = CreateQuantizer(kNumFeatures, model_path);
+ auto vector_quantizer = CreateQuantizer(model_path);
if (vector_quantizer == nullptr) {
LOG(ERROR) << "Could not create Vector Quantizer.";
return nullptr;
20 changes: 6 additions & 14 deletions lyra_encoder.cc
@@ -64,29 +64,23 @@ std::unique_ptr<LyraEncoder> LyraEncoder::Create(
}
}

- const int internal_samples_per_hop =
- GetNumSamplesPerHop(kInternalSampleRateHz);
- const int internal_samples_per_window =
- GetNumSamplesPerWindow(kInternalSampleRateHz);
- auto feature_extractor = CreateFeatureExtractor(
- kInternalSampleRateHz, kNumFeatures, internal_samples_per_hop,
- internal_samples_per_window, model_path);
+ auto feature_extractor = CreateFeatureExtractor(model_path);
if (feature_extractor == nullptr) {
LOG(ERROR) << "Could not create Features Extractor.";
return nullptr;
}

- auto vector_quantizer = CreateQuantizer(kNumFeatures, model_path);
+ auto vector_quantizer = CreateQuantizer(model_path);
if (vector_quantizer == nullptr) {
LOG(ERROR) << "Could not create Vector Quantizer.";
return nullptr;
}

std::unique_ptr<NoiseEstimatorInterface> noise_estimator = nullptr;
if (enable_dtx) {
- noise_estimator =
- NoiseEstimator::Create(sample_rate_hz, internal_samples_per_hop,
- internal_samples_per_window, kNumMelBins);
+ noise_estimator = NoiseEstimator::Create(
+ sample_rate_hz, GetNumSamplesPerHop(kInternalSampleRateHz),
+ GetNumSamplesPerWindow(kInternalSampleRateHz), kNumMelBins);
if (noise_estimator == nullptr) {
LOG(ERROR) << "Could not create Noise Estimator.";
return nullptr;
@@ -127,9 +121,7 @@ std::optional<std::vector<uint8_t>> LyraEncoder::Encode(
audio_for_encoding = absl::MakeConstSpan(processed);
}

- const int internal_samples_per_hop =
- GetNumSamplesPerHop(kInternalSampleRateHz);
- if (audio_for_encoding.size() != internal_samples_per_hop) {
+ if (audio_for_encoding.size() != GetNumSamplesPerHop(kInternalSampleRateHz)) {
LOG(ERROR) << "The number of audio samples has to be exactly "
<< GetNumSamplesPerHop(sample_rate_hz_) << ", but is "
<< audio.size() << ".";
1 change: 0 additions & 1 deletion lyra_encoder_test.cc
@@ -220,7 +220,6 @@ TEST_P(LyraEncoderTest, ReceiveSamplesSucceeds) {

LyraEncoderPeer encoder_peer(
std::move(mock_resampler_), std::move(mock_feature_extractor_),
-
std::move(mock_noise_estimator_), std::move(mock_vector_quantizer_),
external_sample_rate_hz_, num_quantized_bits_,
/*enable_dtx=*/true);
9 changes: 3 additions & 6 deletions lyra_gan_model.cc
@@ -35,7 +35,9 @@ namespace codec {

std::unique_ptr<LyraGanModel> LyraGanModel::Create(
const ghc::filesystem::path& model_path, int num_features) {
- auto model = TfLiteModelWrapper::Create(model_path / "lyragan.tflite", true);
+ auto model =
+ TfLiteModelWrapper::Create(model_path / "lyragan.tflite",
+ /*use_xnn=*/true, /*int8_quantized=*/true);
if (model == nullptr) {
LOG(ERROR) << "Unable to create LyraGAN TFLite model wrapper.";
return nullptr;
@@ -52,11 +54,6 @@ bool LyraGanModel::RunConditioning(const std::vector<float>& features) {
absl::Span<float> input = model_->get_input_tensor<float>(0);
std::copy(features.begin(), features.end(), input.begin());
model_->Invoke();
- for (int i = 1; i < model_->num_input_tensors(); ++i) {
- absl::Span<float> input_state = model_->get_input_tensor<float>(i);
- absl::Span<const float> output_state = model_->get_output_tensor<float>(i);
- std::copy(output_state.begin(), output_state.end(), input_state.begin());
- }
return true;
}

2 changes: 1 addition & 1 deletion model_coeffs/lyra_config.binarypb
@@ -1 +1 @@


Binary file modified model_coeffs/lyragan.tflite
Binary file not shown.
Binary file modified model_coeffs/quantizer.tflite
Binary file not shown.
Binary file modified model_coeffs/soundstream_encoder.tflite
Binary file not shown.
3 changes: 3 additions & 0 deletions patches/BUILD
@@ -0,0 +1,3 @@
+ licenses(["notice"])
+
+ package(default_visibility = ["//visibility:public"])
@@ -0,0 +1,14 @@
diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel
index 9fceffe..e7f9d01 100644
--- a/absl/time/internal/cctz/BUILD.bazel
+++ b/absl/time/internal/cctz/BUILD.bazel
@@ -69,8 +69,5 @@ cc_library(
"include/cctz/zone_info_source.h",
],
linkopts = select({
- ":osx": [
- "-framework Foundation",
- ],
":ios": [
"-framework Foundation",
],
3 changes: 2 additions & 1 deletion residual_vector_quantizer.cc
@@ -36,7 +36,8 @@ namespace codec {
std::unique_ptr<ResidualVectorQuantizer> ResidualVectorQuantizer::Create(
const ghc::filesystem::path& model_path) {
auto quantizer_model =
- TfLiteModelWrapper::Create(model_path / "quantizer.tflite", false);
+ TfLiteModelWrapper::Create(model_path / "quantizer.tflite",
+ /*use_xnn=*/false, /*int8_quantized=*/false);
if (quantizer_model == nullptr) {
LOG(ERROR) << "Unable to create the quantizer TfLite model wrapper.";
return nullptr;
2 changes: 1 addition & 1 deletion residual_vector_quantizer_test.cc
@@ -107,7 +107,7 @@ TEST_P(ResidualVectorQuantizerTest, EncodeDecodeResultsInSimilarFeatures) {
auto decoded_features = quantizer_->DecodeToLossyFeatures(quantized.value());
ASSERT_TRUE(decoded_features.has_value());
EXPECT_EQ(decoded_features.value().size(), features_.size());
- EXPECT_LT(FeatureDistance(decoded_features.value()), 1.1);
+ EXPECT_LT(FeatureDistance(decoded_features.value()), 1.11);
}

INSTANTIATE_TEST_SUITE_P(NumQuantizedBits, ResidualVectorQuantizerTest,
10 changes: 3 additions & 7 deletions soundstream_encoder.cc
@@ -35,8 +35,9 @@ namespace codec {

std::unique_ptr<SoundStreamEncoder> SoundStreamEncoder::Create(
const ghc::filesystem::path& model_path) {
- auto model = TfLiteModelWrapper::Create(
- model_path / "soundstream_encoder.tflite", true);
+ auto model =
+ TfLiteModelWrapper::Create(model_path / "soundstream_encoder.tflite",
+ /*use_xnn=*/true, /*int8_quantized=*/true);
if (model == nullptr) {
LOG(ERROR) << "Unable to create SoundStream encoder TFLite model wrapper.";
return nullptr;
@@ -58,11 +59,6 @@ std::optional<std::vector<float>> SoundStreamEncoder::Extract(
LOG(ERROR) << "Unable to invoke SoundStream encoder TFLite model wrapper.";
return std::nullopt;
}
- for (int i = 1; i < model_->num_input_tensors(); ++i) {
- absl::Span<float> input_state = model_->get_input_tensor<float>(i);
- absl::Span<const float> output_state = model_->get_output_tensor<float>(i);
- std::copy(output_state.begin(), output_state.end(), input_state.begin());
- }
absl::Span<const float> output = model_->get_output_tensor<float>(0);
return std::vector<float>(output.begin(), output.end());
}