Skip to content

Commit

Permalink
Wav2Vec2Bert ASR Inference Support (#1778)
Browse files Browse the repository at this point in the history
* Wav2Vec2 upgrade with Conv1D options

* refining scripts

* refining script again

* fix the formats

* fix the isort format

* refining the library

* update based on the suggestions

* update the variable name

* adding unk_token removal for the Python testing

* adding whitespace

* update Python format

* update variables

* update variables

* update variables

* update variables

* Wav2Vec2Bert ASR support

* sync with the main repository

* update missing parts

* update missing parts2

* update the logic for make_relative_positions

* update test case name

* separate the asymmetric relative positions

* clean empty lines

* update typo

* adding the version check for transformers

* udpate the format

* adding the version check for transformers2

* upgrade transformers 4.41 for python test

* patch from #1711

---------

Co-authored-by: hkwon <[email protected]>
  • Loading branch information
homink and hkwon authored Sep 13, 2024
1 parent 8f4d134 commit cb16c8e
Show file tree
Hide file tree
Showing 34 changed files with 1,131 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ set(SOURCES
src/layers/decoder.cc
src/layers/transformer.cc
src/layers/wav2vec2.cc
src/layers/wav2vec2bert.cc
src/layers/whisper.cc
src/logging.cc
src/models/language_model.cc
Expand All @@ -136,6 +137,7 @@ set(SOURCES
src/models/sequence_to_sequence.cc
src/models/transformer.cc
src/models/wav2vec2.cc
src/models/wav2vec2bert.cc
src/models/whisper.cc
src/ops/activation.cc
src/ops/add.cc
Expand Down Expand Up @@ -182,6 +184,7 @@ set(SOURCES
src/ops/split.cc
src/ops/slide.cc
src/ops/sub.cc
src/ops/sigmoid.cc
src/ops/swish.cc
src/ops/tanh.cc
src/ops/tile.cc
Expand Down
8 changes: 8 additions & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ namespace ctranslate2 {
dim_t keys_length,
dim_t max_position);

StorageView make_asymmetric_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t left_max_position,
dim_t right_max_position);

class RotaryEmbeddings;
class Alibi;

Expand Down Expand Up @@ -53,8 +58,11 @@ namespace ctranslate2 {
dim_t beam_size = 1);
const StorageView* _relative_attention_bias;
const StorageView* _relative_position_keys;
const StorageView* _relative_asymmetric_position_keys;
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
dim_t _relative_left_max_position;
dim_t _relative_right_max_position;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
};
Expand Down
126 changes: 126 additions & 0 deletions include/ctranslate2/layers/wav2vec2bert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#pragma once

#include "ctranslate2/layers/attention.h"
#include "ctranslate2/layers/flash_attention.h"
#include "ctranslate2/layers/common.h"
#include "ctranslate2/layers/transformer.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
namespace layers {

class EncoderLayer : public Layer {
public:
EncoderLayer(const models::Model& model,
const std::string& scope,
const bool pre_norm = true,
const ops::ActivationType activation_type = ops::ActivationType::ReLU,
const bool use_flash_attention = false);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _final_layer_norm.output_type();
}

dim_t output_size() const override {
return _final_layer_norm.output_size();
}

const AttentionLayer& get_self_attention() const {
return *_self_attention;
}

private:
const dim_t _num_heads;
const LayerNorm _ffn1_layer_norm;
const FeedForwardNetwork _ff1;
const LayerNorm _self_attn_layer_norm;
std::unique_ptr<AttentionLayer> _self_attention;
const ops::Transpose _transpose;
const LayerNorm _layer_norm;
const Conv1D _pconv1;
const ops::Sigmoid _sigmoid;
const Conv1D _dconv;
const LayerNorm _dlayer_norm;
const ops::Swish _swish;
const Conv1D _pconv2;
const LayerNorm _ffn2_layer_norm;
const FeedForwardNetwork _ff2;
const LayerNorm _final_layer_norm;
};

class AdapterLayer : public Layer {
public:
AdapterLayer(const models::Model& model,
const std::string& scope,
const bool pre_norm = true,
const ops::ActivationType activation_type = ops::ActivationType::ReLU,
const bool use_flash_attention = false);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _ffn.output_type();
}

dim_t output_size() const override {
return _ffn.output_size();
}

const AttentionLayer& get_self_attention() const {
return *_self_attention;
}

private:
const dim_t _num_heads;
const LayerNorm _residual_layer_norm;
const ops::Transpose _transpose;
const Conv1D _residual_conv;
const ops::Sigmoid _sigmoid;
const LayerNorm _attn_layer_norm;
const Conv1D _attn_conv;
std::unique_ptr<AttentionLayer> _self_attention;
const LayerNorm _ffn_layer_norm;
const FeedForwardNetwork _ffn;
};

class Wav2Vec2BertEncoder : public Layer {
public:
Wav2Vec2BertEncoder(const models::Model& model, const std::string& scope);

void operator()(const StorageView& features, StorageView& output);

DataType output_type() const override {
return _lm_head.output_type();
}

dim_t output_size() const override {
return _lm_head.output_size();
}

dim_t input_size() const {
return 1024;
}

bool is_encoded(const StorageView& features) const {
// Input features shape: [batch_size, input_size, input_time]
// Encoder output shape: [batch_size, input_time // 2, output_size]
//
// input_time is variable so we check that dimension 1 is different than its original value.

return (features.rank() == 3
&& features.dim(2) == output_size()
&& features.dim(1) != input_size());
}

private:
const LayerNorm _fp_layer_norm;
const Dense _fp_projection;
const std::vector<std::unique_ptr<const EncoderLayer>> _encoder_layers;
const std::vector<std::unique_ptr<const AdapterLayer>> _adapt_layers;
const Dense _lm_head;
};

}
}
71 changes: 71 additions & 0 deletions include/ctranslate2/models/wav2vec2bert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once

#include "ctranslate2/layers/wav2vec2bert.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
namespace models {

struct Wav2Vec2BertOptions {
// Maximum generation length.
size_t max_length = 448;

// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
size_t sampling_topk = 1;

// Maximum index of the first predicted timestamp.
size_t max_initial_timestamp_index = 50;

// Suppress blank outputs at the beginning of the sampling.
bool suppress_blank = true;

// List of token IDs to suppress.
// -1 will suppress a default set of symbols as defined in the model config.json file.
std::vector<int> suppress_tokens = {-1};
};


class Wav2Vec2BertModel : public Model {
public:
const Vocabulary& get_vocabulary() const;
size_t current_spec_revision() const override;
bool is_quantizable(const std::string& variable_name) const override;
bool is_linear_weight(const std::string& variable_name) const override;
std::unique_ptr<Model> clone() const override;

bool use_global_int16_scale() const override {
return false;
}

protected:
void initialize(ModelReader& model_reader) override;
private:
std::shared_ptr<const Vocabulary> _vocabulary;
};

class Wav2Vec2BertReplica : public ModelReplica {
public:
static std::unique_ptr<Wav2Vec2BertReplica> create_from_model(const Model& model);

Wav2Vec2BertReplica(const std::shared_ptr<const Wav2Vec2BertModel>& model);

StorageView encode(StorageView features, const bool to_cpu);

private:
const std::shared_ptr<const Wav2Vec2BertModel> _model;
const std::unique_ptr<layers::Wav2Vec2BertEncoder> _encoder;

StorageView maybe_encode(StorageView features);
};

class Wav2Vec2Bert : public ReplicaPool<Wav2Vec2BertReplica> {
public:
using ReplicaPool::ReplicaPool;

std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

};

}
}
1 change: 1 addition & 0 deletions include/ctranslate2/ops/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace ctranslate2 {
GELU,
GELUSigmoid,
Tanh,
Sigmoid,
};

const UnaryOp& get_activation_op(ActivationType type);
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "split.h"
#include "squeeze.h"
#include "sub.h"
#include "sigmoid.h"
#include "swish.h"
#include "tile.h"
#include "topk.h"
Expand Down
21 changes: 21 additions & 0 deletions include/ctranslate2/ops/sigmoid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

class Sigmoid : public UnaryOp {
public:
void operator()(const StorageView& x, StorageView& y) const override;

private:
template <Device D, typename T>
void compute(const StorageView& x, StorageView& y) const {
y.resize_as(x);
primitives<D>::sigmoid(x.data<T>(), y.data<T>(), x.size());
}
};

}
}
2 changes: 2 additions & 0 deletions include/ctranslate2/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ namespace ctranslate2 {
template <typename T>
static void gelu_sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void swish(const T* x, T* y, dim_t size);

static void compute_u8_compensation(const int8_t* b,
Expand Down
1 change: 1 addition & 0 deletions python/cpp/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,6 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace ctranslate2 {
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wav2vec2(py::module& m);
void register_wav2vec2bert(py::module& m);
void register_mpi(py::module& m);

}
Expand Down
Loading

0 comments on commit cb16c8e

Please sign in to comment.