Skip to content

Commit

Permalink
Backward compatibility for the Wav2Vec2 ASR model (#1810)
Browse files Browse the repository at this point in the history
* update description for the wav2vec2 model

* the backward compatiability support for the wav2vec2 ASR model

* dummy

* dummy push

* dummy push

* dummy push

* header update

* dummpy push

---------

Co-authored-by: hkwon <[email protected]>
  • Loading branch information
homink and hkwon authored Nov 5, 2024
1 parent 383d063 commit 5795843
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 41 deletions.
14 changes: 8 additions & 6 deletions include/ctranslate2/layers/wav2vec2.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <optional>
#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
Expand Down Expand Up @@ -81,17 +82,18 @@ namespace ctranslate2 {
}

private:
const Wav2Vec2LayerNormConvLayer _feat_layer0;
const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
const LayerNorm _fp_norm;
const Dense _fp_ff;
const Wav2Vec2PosConvLayer _pos_conv_embed;
const StorageView* _upgraded_model;
std::optional<Wav2Vec2LayerNormConvLayer> _feat_layer0;
std::optional<std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>>> _feat_layers;
std::optional<LayerNorm> _fp_norm;
std::optional<Dense> _fp_ff;
std::optional<Wav2Vec2PosConvLayer> _pos_conv_embed;
const ops::Transpose _transpose;
const ops::GELU _gelu;
const dim_t _num_heads;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const LayerNorm _output_norm;
const Dense _lm_head;
std::optional<Dense> _lm_head;
};

}
Expand Down
5 changes: 3 additions & 2 deletions python/cpp/wav2vec2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ namespace ctranslate2 {
Encodes the input features.
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, 3000]``.
features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
raw audio, as a float array with shape (followed by VAD)
``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
to_cpu: Copy the encoder output to the CPU before returning the value.
Returns:
Expand Down
80 changes: 47 additions & 33 deletions src/layers/wav2vec2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@ namespace ctranslate2 {
}

Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope)
: _feat_layer0(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0)
, _feat_layers(build_layers_list<const Wav2Vec2LayerNormConvLayer>(model,
scope + "/feat_layer",
/*stride=*/2,
/*padding=*/0))
, _fp_norm(model, scope + "/fp_layer_norm")
, _fp_ff(model, scope + "/fp_projection", nullptr, true)
, _pos_conv_embed(model, scope + "/pos_conv_embed")
: _upgraded_model(model.get_variable_if_exists(scope + "/lm_head/weight"))
, _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
, _transpose({0, 2, 1})
, _layers(build_layers_list<const TransformerEncoderLayer>(model,
Expand All @@ -62,8 +55,18 @@ namespace ctranslate2 {
/*pre_norm=*/true,
ops::ActivationType::GELU))
, _output_norm(model, scope + "/layer_norm")
, _lm_head(model, scope + "/lm_head", nullptr, true)
{
if (_upgraded_model) {
_feat_layer0.emplace(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0);
_feat_layers.emplace(build_layers_list<const Wav2Vec2LayerNormConvLayer>(model,
scope + "/feat_layer",
/*stride=*/2,
/*padding=*/0));
_fp_norm.emplace(model, scope + "/fp_layer_norm");
_fp_ff.emplace(model, scope + "/fp_projection", nullptr, true);
_pos_conv_embed.emplace(model, scope + "/pos_conv_embed");
_lm_head.emplace(model, scope + "/lm_head", nullptr, true);
}
}

void Wav2Vec2Encoder::operator()(const StorageView& features, StorageView& output) {
Expand All @@ -74,33 +77,44 @@ namespace ctranslate2 {
throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
+ std::to_string(features.rank())
+ " dimension(s) instead");

// Wav2Vec2FeatureExtractor------------------------------------
StorageView feat_buffer(features.dtype(), features.device());
StorageView feat_buffer2(features.dtype(), features.device());
feat_buffer = std::move(features);
_feat_layer0(feat_buffer, output);
feat_buffer = std::move(output);
for (dim_t l = 0; l < _feat_layers.size(); l++) {
(*_feat_layers[l])(feat_buffer, output);
if (l < _feat_layers.size() - 1 ) {
feat_buffer = std::move(output);
if (_upgraded_model) {
// Wav2Vec2FeatureExtractor------------------------------------
StorageView feat_buffer(features.dtype(), features.device());
StorageView feat_buffer2(features.dtype(), features.device());
feat_buffer = std::move(features);
(*_feat_layer0)(feat_buffer, output); //_feat_layer0(feat_buffer, output);
feat_buffer = std::move(output);
for (dim_t l = 0; l < _feat_layers->size(); l++) {
(*_feat_layers.value()[l])(feat_buffer, output);
if (l < _feat_layers->size() - 1 ) {
feat_buffer = std::move(output);
}
}
_transpose(output, feat_buffer);
// Wav2Vec2FeatureProjection-----------------------------------
(*_fp_norm)(feat_buffer, output); //_fp_norm(feat_buffer, output);
(*_fp_ff)(output, feat_buffer); //_fp_ff(output, feat_buffer);
// Wav2Vec2PositionalConvEmbedding-----------------------------
(*_pos_conv_embed)(feat_buffer, feat_buffer2); //_pos_conv_embed(feat_buffer, feat_buffer2);
// Wav2Vec2EncoderLayerStableLayerNorm-------------------------
for (const auto& layer : _layers) {
(*layer)(feat_buffer2, nullptr, feat_buffer);
feat_buffer2 = std::move(feat_buffer);
}
_output_norm(feat_buffer2, feat_buffer);

(*_lm_head)(feat_buffer, output); //_lm_head(feat_buffer, output);
}
_transpose(output, feat_buffer);
// Wav2Vec2FeatureProjection-----------------------------------
_fp_norm(feat_buffer, output);
_fp_ff(output, feat_buffer);
// Wav2Vec2PositionalConvEmbedding-----------------------------
_pos_conv_embed(feat_buffer, feat_buffer2);
// Wav2Vec2EncoderLayerStableLayerNorm-------------------------
for (const auto& layer : _layers) {
(*layer)(feat_buffer2, nullptr, feat_buffer);
feat_buffer2 = std::move(feat_buffer);
}
_output_norm(feat_buffer2, feat_buffer);
else { // backward compatibility for the previous converted model
StorageView input(output_type(), features.device());
input = features;
for (const auto& layer : _layers) {
(*layer)(input, nullptr, output);
input = std::move(output);
}

_lm_head(feat_buffer, output);
_output_norm(input, output);
}
}

}
Expand Down

0 comments on commit 5795843

Please sign in to comment.