Wav2Vec2Bert ASR Inference Support (#1778)
* Wav2Vec2 upgrade with Conv1D options
* refining scripts
* refining script again
* fix the formats
* fix the isort format
* refining the library
* update based on the suggestions
* update the variable name
* adding unk_token removal for the Python testing
* adding whitespace
* update Python format
* update variables
* update variables
* update variables
* update variables
* Wav2Vec2Bert ASR support
* sync with the main repository
* update missing parts
* update missing parts2
* update the logic for make_relative_positions
* update test case name
* separate the asymmetric relative positions
* clean empty lines
* update typo
* adding the version check for transformers
* update the format
* adding the version check for transformers2
* upgrade transformers 4.41 for python test
* patch from #1711

---------

Co-authored-by: hkwon <[email protected]>
Showing 34 changed files with 1,131 additions and 7 deletions.
@@ -0,0 +1,126 @@
#pragma once

#include "ctranslate2/layers/attention.h"
#include "ctranslate2/layers/flash_attention.h"
#include "ctranslate2/layers/common.h"
#include "ctranslate2/layers/transformer.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
  namespace layers {

    class EncoderLayer : public Layer {
    public:
      EncoderLayer(const models::Model& model,
                   const std::string& scope,
                   const bool pre_norm = true,
                   const ops::ActivationType activation_type = ops::ActivationType::ReLU,
                   const bool use_flash_attention = false);

      void operator()(const StorageView& input, StorageView& output) const;

      DataType output_type() const override {
        return _final_layer_norm.output_type();
      }

      dim_t output_size() const override {
        return _final_layer_norm.output_size();
      }

      const AttentionLayer& get_self_attention() const {
        return *_self_attention;
      }

    private:
      const dim_t _num_heads;
      const LayerNorm _ffn1_layer_norm;
      const FeedForwardNetwork _ff1;
      const LayerNorm _self_attn_layer_norm;
      std::unique_ptr<AttentionLayer> _self_attention;
      const ops::Transpose _transpose;
      const LayerNorm _layer_norm;
      const Conv1D _pconv1;
      const ops::Sigmoid _sigmoid;
      const Conv1D _dconv;
      const LayerNorm _dlayer_norm;
      const ops::Swish _swish;
      const Conv1D _pconv2;
      const LayerNorm _ffn2_layer_norm;
      const FeedForwardNetwork _ff2;
      const LayerNorm _final_layer_norm;
    };

    class AdapterLayer : public Layer {
    public:
      AdapterLayer(const models::Model& model,
                   const std::string& scope,
                   const bool pre_norm = true,
                   const ops::ActivationType activation_type = ops::ActivationType::ReLU,
                   const bool use_flash_attention = false);

      void operator()(const StorageView& input, StorageView& output) const;

      DataType output_type() const override {
        return _ffn.output_type();
      }

      dim_t output_size() const override {
        return _ffn.output_size();
      }

      const AttentionLayer& get_self_attention() const {
        return *_self_attention;
      }

    private:
      const dim_t _num_heads;
      const LayerNorm _residual_layer_norm;
      const ops::Transpose _transpose;
      const Conv1D _residual_conv;
      const ops::Sigmoid _sigmoid;
      const LayerNorm _attn_layer_norm;
      const Conv1D _attn_conv;
      std::unique_ptr<AttentionLayer> _self_attention;
      const LayerNorm _ffn_layer_norm;
      const FeedForwardNetwork _ffn;
    };

    class Wav2Vec2BertEncoder : public Layer {
    public:
      Wav2Vec2BertEncoder(const models::Model& model, const std::string& scope);

      void operator()(const StorageView& features, StorageView& output);

      DataType output_type() const override {
        return _lm_head.output_type();
      }

      dim_t output_size() const override {
        return _lm_head.output_size();
      }

      dim_t input_size() const {
        return 1024;
      }

      bool is_encoded(const StorageView& features) const {
        // Input features shape:  [batch_size, input_size, input_time]
        // Encoder output shape:  [batch_size, input_time // 2, output_size]
        //
        // input_time is variable so we check that dimension 1 is different than its original value.
        return (features.rank() == 3
                && features.dim(2) == output_size()
                && features.dim(1) != input_size());
      }

    private:
      const LayerNorm _fp_layer_norm;
      const Dense _fp_projection;
      const std::vector<std::unique_ptr<const EncoderLayer>> _encoder_layers;
      const std::vector<std::unique_ptr<const AdapterLayer>> _adapt_layers;
      const Dense _lm_head;
    };

  }
}
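
Note: the member layout of EncoderLayer above (two feed-forward blocks around a self-attention module, plus a pointwise/depthwise convolution module with a sigmoid gate and Swish activation) follows the Conformer-style block used by the Wav2Vec2-BERT encoder. The snippet below is only a schematic of that forward pass, not the actual operator() implementation: every helper is a hypothetical identity stub standing in for the corresponding CTranslate2 layer, and the 0.5 half-step feed-forward scaling is an assumption carried over from the standard Conformer formulation.

#include <cstddef>
#include <vector>

using Tensor = std::vector<float>;

// Hypothetical identity stubs: in the real layer these would run LayerNorm,
// FeedForwardNetwork, AttentionLayer and the Conv1D-based modules.
static Tensor layer_norm(const Tensor& x) { return x; }
static Tensor feed_forward(const Tensor& x) { return x; }
static Tensor self_attention(const Tensor& x) { return x; }
static Tensor conv_module(const Tensor& x) {
  // _layer_norm -> _pconv1 gated by _sigmoid (GLU) -> _dconv (depthwise)
  // -> _dlayer_norm -> _swish -> _pconv2
  return x;
}

// Residual connection with an optional scaling of the branch output.
static Tensor add_scaled(const Tensor& residual, const Tensor& branch, float scale) {
  Tensor out(residual.size());
  for (std::size_t i = 0; i < residual.size(); ++i)
    out[i] = residual[i] + scale * branch[i];
  return out;
}

Tensor encoder_layer_forward(const Tensor& input) {
  // 1. First half-step feed-forward module (_ffn1_layer_norm, _ff1).
  Tensor x = add_scaled(input, feed_forward(layer_norm(input)), 0.5f);
  // 2. Self-attention module (_self_attn_layer_norm, _self_attention).
  x = add_scaled(x, self_attention(layer_norm(x)), 1.0f);
  // 3. Convolution module (_pconv1, _dconv, _pconv2 with GLU and Swish).
  x = add_scaled(x, conv_module(x), 1.0f);
  // 4. Second half-step feed-forward module (_ffn2_layer_norm, _ff2).
  x = add_scaled(x, feed_forward(layer_norm(x)), 0.5f);
  // 5. Final normalization (_final_layer_norm).
  return layer_norm(x);
}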
@@ -0,0 +1,71 @@
#pragma once

#include "ctranslate2/layers/wav2vec2bert.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
  namespace models {

    struct Wav2Vec2BertOptions {
      // Maximum generation length.
      size_t max_length = 448;

      // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
      size_t sampling_topk = 1;

      // Maximum index of the first predicted timestamp.
      size_t max_initial_timestamp_index = 50;

      // Suppress blank outputs at the beginning of the sampling.
      bool suppress_blank = true;

      // List of token IDs to suppress.
      // -1 will suppress a default set of symbols as defined in the model config.json file.
      std::vector<int> suppress_tokens = {-1};
    };

    class Wav2Vec2BertModel : public Model {
    public:
      const Vocabulary& get_vocabulary() const;
      size_t current_spec_revision() const override;
      bool is_quantizable(const std::string& variable_name) const override;
      bool is_linear_weight(const std::string& variable_name) const override;
      std::unique_ptr<Model> clone() const override;

      bool use_global_int16_scale() const override {
        return false;
      }

    protected:
      void initialize(ModelReader& model_reader) override;

    private:
      std::shared_ptr<const Vocabulary> _vocabulary;
    };

    class Wav2Vec2BertReplica : public ModelReplica {
    public:
      static std::unique_ptr<Wav2Vec2BertReplica> create_from_model(const Model& model);

      Wav2Vec2BertReplica(const std::shared_ptr<const Wav2Vec2BertModel>& model);

      StorageView encode(StorageView features, const bool to_cpu);

    private:
      const std::shared_ptr<const Wav2Vec2BertModel> _model;
      const std::unique_ptr<layers::Wav2Vec2BertEncoder> _encoder;

      StorageView maybe_encode(StorageView features);
    };

    class Wav2Vec2Bert : public ReplicaPool<Wav2Vec2BertReplica> {
    public:
      using ReplicaPool::ReplicaPool;

      std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
    };

  }
}
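
A minimal sketch of how the asynchronous encode API declared above could be called from client code. It assumes a models::Wav2Vec2Bert pool has already been constructed (through the inherited ReplicaPool constructors, not shown here) and that the features tensor holds the [batch_size, 1024, input_time] input described in the layers header; the helper name is hypothetical.

#include <future>

#include "ctranslate2/models/wav2vec2bert.h"

// Hypothetical helper: 'pool' and 'features' are assumed to be prepared by the caller.
ctranslate2::StorageView run_encoder(ctranslate2::models::Wav2Vec2Bert& pool,
                                     const ctranslate2::StorageView& features) {
  // Submit the features to a model replica; the call returns immediately with a future.
  std::future<ctranslate2::StorageView> output = pool.encode(features, /*to_cpu=*/true);
  // Block until the encoder output ([batch_size, input_time // 2, output_size]) is ready.
  return output.get();
}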
@@ -0,0 +1,21 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
  namespace ops {

    class Sigmoid : public UnaryOp {
    public:
      void operator()(const StorageView& x, StorageView& y) const override;

    private:
      template <Device D, typename T>
      void compute(const StorageView& x, StorageView& y) const {
        y.resize_as(x);
        primitives<D>::sigmoid(x.data<T>(), y.data<T>(), x.size());
      }
    };

  }
}
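
For illustration, a small example of the new op in use: the Sigmoid functor follows CTranslate2's usual unary-op pattern of resizing the output to match the input and then dispatching to the device- and type-specialized primitive. The include path and tensor contents below are assumptions made for this sketch.

#include <vector>

#include "ctranslate2/ops/sigmoid.h"    // assumed header location for the new op
#include "ctranslate2/storage_view.h"

void sigmoid_example() {
  // Small CPU float tensor of shape [2, 2].
  ctranslate2::StorageView x({2, 2}, std::vector<float>{-1.f, 0.f, 1.f, 2.f});
  ctranslate2::StorageView y;
  // Element-wise sigmoid: y[i] = 1 / (1 + exp(-x[i])).
  ctranslate2::ops::Sigmoid()(x, y);
}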