From cb16c8e670d47f060c355c52a3009e26e4861d36 Mon Sep 17 00:00:00 2001 From: homink Date: Fri, 13 Sep 2024 02:13:05 -0700 Subject: [PATCH] Wav2Vec2Bert ASR Inference Support (#1778) * Wav2Vec2 upgrade with Conv1D options * refining scripts * refining script again * fix the formats * fix the isort format * refining the library * update based on the suggestions * update the variable name * adding unk_token removal for the Python testing * adding whitespace * update Python format * update variables * update variables * update variables * update variables * Wav2Vec2Bert ASR support * sync with the main repository * update missing parts * update missing parts2 * update the logic for make_relative_positions * update test case name * separate the asymmetric relative positions * clean empty lines * update typo * adding the version check for transformers * udpate the format * adding the version check for transformers2 * upgrade transformers 4.41 for python test * patch from https://github.com/OpenNMT/CTranslate2/issues/1711 --------- Co-authored-by: hkwon --- CMakeLists.txt | 3 + include/ctranslate2/layers/attention.h | 8 + include/ctranslate2/layers/wav2vec2bert.h | 126 +++++++++++ include/ctranslate2/models/wav2vec2bert.h | 71 ++++++ include/ctranslate2/ops/activation.h | 1 + include/ctranslate2/ops/ops.h | 1 + include/ctranslate2/ops/sigmoid.h | 21 ++ include/ctranslate2/primitives.h | 2 + python/cpp/module.cc | 1 + python/cpp/module.h | 1 + python/cpp/wav2vec2bert.cc | 124 +++++++++++ python/ctranslate2/converters/transformers.py | 125 ++++++++++- python/ctranslate2/models/__init__.py | 1 + python/ctranslate2/specs/__init__.py | 1 + python/ctranslate2/specs/attention_spec.py | 6 + python/ctranslate2/specs/common_spec.py | 1 + python/ctranslate2/specs/wav2vec2bert_spec.py | 86 +++++++ python/tests/requirements.txt | 2 +- python/tests/test_transformers.py | 80 +++++++ src/cpu/kernels.cc | 15 ++ src/cpu/kernels.h | 2 + src/cpu/primitives.cc | 9 + src/cuda/helpers.h | 8 + src/cuda/primitives.cu | 7 + src/layers/attention.cc | 54 ++++- src/layers/wav2vec2bert.cc | 210 ++++++++++++++++++ src/models/model_factory.cc | 3 + src/models/wav2vec2bert.cc | 118 ++++++++++ src/ops/activation.cc | 5 + src/ops/bias_add_gpu.cu | 5 + src/ops/dequantize_gpu.cu | 6 + src/ops/sigmoid.cc | 14 ++ tests/layers_test.cc | 10 + tests/ops_test.cc | 11 + 34 files changed, 1131 insertions(+), 7 deletions(-) create mode 100644 include/ctranslate2/layers/wav2vec2bert.h create mode 100644 include/ctranslate2/models/wav2vec2bert.h create mode 100644 include/ctranslate2/ops/sigmoid.h create mode 100644 python/cpp/wav2vec2bert.cc create mode 100644 python/ctranslate2/specs/wav2vec2bert_spec.py create mode 100644 src/layers/wav2vec2bert.cc create mode 100644 src/models/wav2vec2bert.cc create mode 100644 src/ops/sigmoid.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 52610ac89..62fc33640 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -127,6 +127,7 @@ set(SOURCES src/layers/decoder.cc src/layers/transformer.cc src/layers/wav2vec2.cc + src/layers/wav2vec2bert.cc src/layers/whisper.cc src/logging.cc src/models/language_model.cc @@ -136,6 +137,7 @@ set(SOURCES src/models/sequence_to_sequence.cc src/models/transformer.cc src/models/wav2vec2.cc + src/models/wav2vec2bert.cc src/models/whisper.cc src/ops/activation.cc src/ops/add.cc @@ -182,6 +184,7 @@ set(SOURCES src/ops/split.cc src/ops/slide.cc src/ops/sub.cc + src/ops/sigmoid.cc src/ops/swish.cc src/ops/tanh.cc src/ops/tile.cc diff --git 
a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h index 87b21f725..5778a028c 100644 --- a/include/ctranslate2/layers/attention.h +++ b/include/ctranslate2/layers/attention.h @@ -10,6 +10,11 @@ namespace ctranslate2 { dim_t keys_length, dim_t max_position); + StorageView make_asymmetric_relative_positions(dim_t queries_length, + dim_t keys_length, + dim_t left_max_position, + dim_t right_max_position); + class RotaryEmbeddings; class Alibi; @@ -53,8 +58,11 @@ namespace ctranslate2 { dim_t beam_size = 1); const StorageView* _relative_attention_bias; const StorageView* _relative_position_keys; + const StorageView* _relative_asymmetric_position_keys; const StorageView* _relative_position_values; dim_t _maximum_relative_position; + dim_t _relative_left_max_position; + dim_t _relative_right_max_position; const bool _merge_time_and_head_dims; const dim_t _cache_time_dim; }; diff --git a/include/ctranslate2/layers/wav2vec2bert.h b/include/ctranslate2/layers/wav2vec2bert.h new file mode 100644 index 000000000..59ece0021 --- /dev/null +++ b/include/ctranslate2/layers/wav2vec2bert.h @@ -0,0 +1,126 @@ +#pragma once + +#include "ctranslate2/layers/attention.h" +#include "ctranslate2/layers/flash_attention.h" +#include "ctranslate2/layers/common.h" +#include "ctranslate2/layers/transformer.h" +#include "ctranslate2/padder.h" + +namespace ctranslate2 { + namespace layers { + + class EncoderLayer : public Layer { + public: + EncoderLayer(const models::Model& model, + const std::string& scope, + const bool pre_norm = true, + const ops::ActivationType activation_type = ops::ActivationType::ReLU, + const bool use_flash_attention = false); + + void operator()(const StorageView& input, StorageView& output) const; + + DataType output_type() const override { + return _final_layer_norm.output_type(); + } + + dim_t output_size() const override { + return _final_layer_norm.output_size(); + } + + const AttentionLayer& get_self_attention() const { + return *_self_attention; + } + + private: + const dim_t _num_heads; + const LayerNorm _ffn1_layer_norm; + const FeedForwardNetwork _ff1; + const LayerNorm _self_attn_layer_norm; + std::unique_ptr _self_attention; + const ops::Transpose _transpose; + const LayerNorm _layer_norm; + const Conv1D _pconv1; + const ops::Sigmoid _sigmoid; + const Conv1D _dconv; + const LayerNorm _dlayer_norm; + const ops::Swish _swish; + const Conv1D _pconv2; + const LayerNorm _ffn2_layer_norm; + const FeedForwardNetwork _ff2; + const LayerNorm _final_layer_norm; + }; + + class AdapterLayer : public Layer { + public: + AdapterLayer(const models::Model& model, + const std::string& scope, + const bool pre_norm = true, + const ops::ActivationType activation_type = ops::ActivationType::ReLU, + const bool use_flash_attention = false); + + void operator()(const StorageView& input, StorageView& output) const; + + DataType output_type() const override { + return _ffn.output_type(); + } + + dim_t output_size() const override { + return _ffn.output_size(); + } + + const AttentionLayer& get_self_attention() const { + return *_self_attention; + } + + private: + const dim_t _num_heads; + const LayerNorm _residual_layer_norm; + const ops::Transpose _transpose; + const Conv1D _residual_conv; + const ops::Sigmoid _sigmoid; + const LayerNorm _attn_layer_norm; + const Conv1D _attn_conv; + std::unique_ptr _self_attention; + const LayerNorm _ffn_layer_norm; + const FeedForwardNetwork _ffn; + }; + + class Wav2Vec2BertEncoder : public Layer { + public: + Wav2Vec2BertEncoder(const 
models::Model& model, const std::string& scope); + + void operator()(const StorageView& features, StorageView& output); + + DataType output_type() const override { + return _lm_head.output_type(); + } + + dim_t output_size() const override { + return _lm_head.output_size(); + } + + dim_t input_size() const { + return 1024; + } + + bool is_encoded(const StorageView& features) const { + // Input features shape: [batch_size, input_size, input_time] + // Encoder output shape: [batch_size, input_time // 2, output_size] + // + // input_time is variable so we check that dimension 1 is different than its original value. + + return (features.rank() == 3 + && features.dim(2) == output_size() + && features.dim(1) != input_size()); + } + + private: + const LayerNorm _fp_layer_norm; + const Dense _fp_projection; + const std::vector> _encoder_layers; + const std::vector> _adapt_layers; + const Dense _lm_head; + }; + + } +} diff --git a/include/ctranslate2/models/wav2vec2bert.h b/include/ctranslate2/models/wav2vec2bert.h new file mode 100644 index 000000000..68e9a886a --- /dev/null +++ b/include/ctranslate2/models/wav2vec2bert.h @@ -0,0 +1,71 @@ +#pragma once + +#include "ctranslate2/layers/wav2vec2bert.h" +#include "ctranslate2/models/model.h" +#include "ctranslate2/replica_pool.h" + +namespace ctranslate2 { + namespace models { + + struct Wav2Vec2BertOptions { + // Maximum generation length. + size_t max_length = 448; + + // Randomly sample from the top K candidates (set 0 to sample from the full distribution). + size_t sampling_topk = 1; + + // Maximum index of the first predicted timestamp. + size_t max_initial_timestamp_index = 50; + + // Suppress blank outputs at the beginning of the sampling. + bool suppress_blank = true; + + // List of token IDs to suppress. + // -1 will suppress a default set of symbols as defined in the model config.json file. 
+ std::vector suppress_tokens = {-1}; + }; + + + class Wav2Vec2BertModel : public Model { + public: + const Vocabulary& get_vocabulary() const; + size_t current_spec_revision() const override; + bool is_quantizable(const std::string& variable_name) const override; + bool is_linear_weight(const std::string& variable_name) const override; + std::unique_ptr clone() const override; + + bool use_global_int16_scale() const override { + return false; + } + + protected: + void initialize(ModelReader& model_reader) override; + private: + std::shared_ptr _vocabulary; + }; + + class Wav2Vec2BertReplica : public ModelReplica { + public: + static std::unique_ptr create_from_model(const Model& model); + + Wav2Vec2BertReplica(const std::shared_ptr& model); + + StorageView encode(StorageView features, const bool to_cpu); + + private: + const std::shared_ptr _model; + const std::unique_ptr _encoder; + + StorageView maybe_encode(StorageView features); + }; + + class Wav2Vec2Bert : public ReplicaPool { + public: + using ReplicaPool::ReplicaPool; + + std::future encode(const StorageView& features, const bool to_cpu); + + }; + + } +} diff --git a/include/ctranslate2/ops/activation.h b/include/ctranslate2/ops/activation.h index f500fcf9e..a9bff98cd 100644 --- a/include/ctranslate2/ops/activation.h +++ b/include/ctranslate2/ops/activation.h @@ -13,6 +13,7 @@ namespace ctranslate2 { GELU, GELUSigmoid, Tanh, + Sigmoid, }; const UnaryOp& get_activation_op(ActivationType type); diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index ed9db4265..2a735e394 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -22,6 +22,7 @@ #include "split.h" #include "squeeze.h" #include "sub.h" +#include "sigmoid.h" #include "swish.h" #include "tile.h" #include "topk.h" diff --git a/include/ctranslate2/ops/sigmoid.h b/include/ctranslate2/ops/sigmoid.h new file mode 100644 index 000000000..0921cd7ce --- /dev/null +++ b/include/ctranslate2/ops/sigmoid.h @@ -0,0 +1,21 @@ +#pragma once + +#include "op.h" + +namespace ctranslate2 { + namespace ops { + + class Sigmoid : public UnaryOp { + public: + void operator()(const StorageView& x, StorageView& y) const override; + + private: + template + void compute(const StorageView& x, StorageView& y) const { + y.resize_as(x); + primitives::sigmoid(x.data(), y.data(), x.size()); + } + }; + + } +} diff --git a/include/ctranslate2/primitives.h b/include/ctranslate2/primitives.h index bed80c8ff..571121554 100644 --- a/include/ctranslate2/primitives.h +++ b/include/ctranslate2/primitives.h @@ -181,6 +181,8 @@ namespace ctranslate2 { template static void gelu_sigmoid(const T* x, T* y, dim_t size); template + static void sigmoid(const T* x, T* y, dim_t size); + template static void swish(const T* x, T* y, dim_t size); static void compute_u8_compensation(const int8_t* b, diff --git a/python/cpp/module.cc b/python/cpp/module.cc index 4489d5314..550aea5b2 100644 --- a/python/cpp/module.cc +++ b/python/cpp/module.cc @@ -87,5 +87,6 @@ PYBIND11_MODULE(_ext, m) ctranslate2::python::register_encoder(m); ctranslate2::python::register_whisper(m); ctranslate2::python::register_wav2vec2(m); + ctranslate2::python::register_wav2vec2bert(m); ctranslate2::python::register_mpi(m); } diff --git a/python/cpp/module.h b/python/cpp/module.h index 9c9a9a2ff..71d4b3b29 100644 --- a/python/cpp/module.h +++ b/python/cpp/module.h @@ -18,6 +18,7 @@ namespace ctranslate2 { void register_translator(py::module& m); void register_whisper(py::module& m); void 
register_wav2vec2(py::module& m); + void register_wav2vec2bert(py::module& m); void register_mpi(py::module& m); } diff --git a/python/cpp/wav2vec2bert.cc b/python/cpp/wav2vec2bert.cc new file mode 100644 index 000000000..b528f0ac1 --- /dev/null +++ b/python/cpp/wav2vec2bert.cc @@ -0,0 +1,124 @@ +#include "module.h" + +#include + +#include "replica_pool.h" + +namespace ctranslate2 { + namespace python { + + class Wav2Vec2BertWrapper : public ReplicaPoolHelper { + public: + using ReplicaPoolHelper::ReplicaPoolHelper; + + StorageView encode(const StorageView& features, const bool to_cpu) { + std::shared_lock lock(_mutex); + assert_model_is_ready(); + return _pool->encode(features, to_cpu).get(); + } + }; + + + void register_wav2vec2bert(py::module& m) { + py::class_( + m, "Wav2Vec2Bert", + R"pbdoc( + Implements the Wav2Vec2Bert speech recognition model published by Facebook. + + See Also: + https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec + )pbdoc") + + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), + py::arg("model_path"), + py::arg("device")="cpu", + py::kw_only(), + py::arg("device_index")=0, + py::arg("compute_type")="default", + py::arg("inter_threads")=1, + py::arg("intra_threads")=0, + py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, + py::arg("tensor_parallel")=false, + py::arg("files")=py::none(), + R"pbdoc( + Initializes a Wav2Vec2Bert model from a converted model. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this model on. + compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Number of workers to allow executing multiple batches in parallel. + intra_threads: Number of OpenMP threads per worker (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer + tensor_parallel: run model with tensor parallel mode + files: Load model files from the memory. This argument is a dictionary mapping + file names to file contents as file-like or bytes objects. If this is set, + :obj:`model_path` acts as an identifier for this model. 
+ )pbdoc") + + .def_property_readonly("device", &Wav2Vec2BertWrapper::device, + "Device this model is running on.") + .def_property_readonly("device_index", &Wav2Vec2BertWrapper::device_index, + "List of device IDs where this model is running on.") + .def_property_readonly("compute_type", &Wav2Vec2BertWrapper::compute_type, + "Computation type used by the model.") + .def_property_readonly("num_workers", &Wav2Vec2BertWrapper::num_replicas, + "Number of model workers backing this instance.") + .def_property_readonly("num_queued_batches", &Wav2Vec2BertWrapper::num_queued_batches, + "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &Wav2Vec2BertWrapper::tensor_parallel, + "Run model with tensor parallel mode.") + .def_property_readonly("num_active_batches", &Wav2Vec2BertWrapper::num_active_batches, + "Number of batches waiting to be processed or currently processed.") + + .def("encode", &Wav2Vec2BertWrapper::encode, + py::arg("features"), + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Encodes the input features. + + Arguments: + features: Mel spectogram of the audio, as a float array with shape + ``[batch_size, 80, 3000]``. + to_cpu: Copy the encoder output to the CPU before returning the value. + + Returns: + The encoder output. + )pbdoc") + + .def("unload_model", &Wav2Vec2BertWrapper::unload_model, + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Unloads the model attached to this wav2vec2bert but keep enough runtime context + to quickly resume wav2vec2bert on the initial device. + + Arguments: + to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded. + )pbdoc") + + .def("load_model", &Wav2Vec2BertWrapper::load_model, + py::arg("keep_cache")=false, + py::call_guard(), + R"pbdoc( + Loads the model back to the initial device. + + Arguments: + keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists. 
+ )pbdoc") + + .def_property_readonly("model_is_loaded", &Wav2Vec2BertWrapper::model_is_loaded, + "Whether the model is loaded on the initial device and ready to be used.") + ; + } + + } +} diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index cd8e8aef4..d90ff3569 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -23,6 +23,7 @@ model_spec, transformer_spec, wav2vec2_spec, + wav2vec2bert_spec, whisper_spec, ) @@ -356,7 +357,17 @@ def set_attention(self, spec, attention, self_attention=False): self.set_linear(spec.linear[-1], attention.out_proj) def set_common_layers(self, spec, module): - spec.scale_embeddings = module.embed_scale + import math + + if not hasattr(module, "embed_scale"): + embed_scale = ( + math.sqrt(module.config.d_model) + if module.config.scale_embedding + else 1.0 + ) + else: + embed_scale = module.embed_scale + spec.scale_embeddings = embed_scale self.set_position_encodings(spec.position_encodings, module.embed_positions) self.set_embeddings( ( @@ -1059,6 +1070,118 @@ def set_common_layers(self, spec, module): self.set_layer_norm(spec.layer_norm, module.layer_norm) +@register_loader("Wav2Vec2BertConfig") +class Wav2Vec2BertLoader(BartLoader): + @property + def architecture_name(self): + return "Wav2Vec2BertForCTC" + + def get_model_spec(self, model): + spec = wav2vec2bert_spec.Wav2Vec2BertSpec( + model.wav2vec2_bert.config.num_adapter_layers, + model.wav2vec2_bert.config.num_hidden_layers, + ) + self.set_encoder(spec.encoder, model) + return spec + + def set_config(self, config, model, tokenizer): + return + + def get_vocabulary(self, model, tokenizer): + return tokenizer.get_vocab() + + def set_vocabulary(self, spec, tokens): + spec.register_vocabulary(tokens) + + def set_feature_projection(self, spec, feature_projection): + self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm) + self.set_linear(spec.fp_projection, feature_projection.projection) + + def set_attention( + self, spec, attention, left_max_position=None, right_max_position=None + ): + split_layers = [common_spec.LinearSpec() for _ in range(3)] + self.set_linear(split_layers[0], attention.linear_q) + self.set_linear(split_layers[1], attention.linear_k) + self.set_linear(split_layers[2], attention.linear_v) + utils.fuse_linear(spec.linear[0], split_layers) + self.set_linear(spec.linear[-1], attention.linear_out) + if left_max_position or right_max_position: + spec.relative_asymmetric_position_keys = attention.distance_embedding.weight + spec.relative_left_max_position = np.dtype("int32").type(left_max_position) + spec.relative_right_max_position = np.dtype("int32").type( + right_max_position + ) + + def set_wav2vec2bert_encoder( + self, spec_layers, layers, left_max_position, right_max_position + ): + for slayer, layer in zip(spec_layers, layers): + self.set_layer_norm(slayer.enc_ffn1_layer_norm, layer.ffn1_layer_norm) + self.set_linear(slayer.enc_ffn1.linear_0, layer.ffn1.intermediate_dense) + self.set_linear(slayer.enc_ffn1.linear_1, layer.ffn1.output_dense) + self.set_attention( + slayer.enc_attn, layer.self_attn, left_max_position, right_max_position + ) + self.set_layer_norm(slayer.enc_attn_layer_norm, layer.self_attn_layer_norm) + self.set_layer_norm( + slayer.enc_conv_layer_norm, layer.conv_module.layer_norm + ) + self.set_conv1d( + slayer.enc_conv_pointwise_conv1, layer.conv_module.pointwise_conv1 + ) + self.set_conv1d( + slayer.enc_conv_depthwise_conv, 
layer.conv_module.depthwise_conv + ) + self.set_layer_norm( + slayer.enc_conv_depthwise_layer_norm, + layer.conv_module.depthwise_layer_norm, + ) + self.set_conv1d( + slayer.enc_conv_pointwise_conv2, layer.conv_module.pointwise_conv2 + ) + self.set_layer_norm(slayer.enc_ffn2_layer_norm, layer.ffn2_layer_norm) + self.set_linear(slayer.enc_ffn2.linear_0, layer.ffn2.intermediate_dense) + self.set_linear(slayer.enc_ffn2.linear_1, layer.ffn2.output_dense) + self.set_layer_norm(slayer.enc_final_layer_norm, layer.final_layer_norm) + + def set_wav2vec2bert_adapter(self, spec_layers, layers): + for slayer, layer in zip(spec_layers, layers): + self.set_layer_norm( + slayer.adpt_residual_layer_norm, layer.residual_layer_norm + ) + self.set_conv1d(slayer.adpt_residual_conv, layer.residual_conv) + self.set_layer_norm(slayer.adpt_attn_layer_norm, layer.self_attn_layer_norm) + self.set_conv1d(slayer.adpt_attn_conv, layer.self_attn_conv) + self.set_attention(slayer.adpt_attn_layer, layer.self_attn) + self.set_layer_norm(slayer.adpt_ffn_layer_norm, layer.ffn_layer_norm) + self.set_linear(slayer.adpt_ffn.linear_0, layer.ffn.intermediate_dense) + self.set_linear(slayer.adpt_ffn.linear_1, layer.ffn.output_dense) + + def set_encoder(self, spec, model): + self.set_feature_projection(spec, model.wav2vec2_bert.feature_projection) + self.set_wav2vec2bert_encoder( + spec.encoder_layers, + model.wav2vec2_bert.encoder.layers, + model.wav2vec2_bert.config.left_max_position_embeddings, + model.wav2vec2_bert.config.right_max_position_embeddings, + ) + self.set_wav2vec2bert_adapter( + spec.adapter_layers, model.wav2vec2_bert.adapter.layers + ) + self.set_linear(spec.lm_head, model.lm_head) + + def set_conv1d(self, spec, module): + spec.weight = module.weight + if module.bias is not None: + spec.bias = module.bias + + def set_layer_norm(self, spec, module): + spec.gamma = module.weight + if module.bias is not None: + spec.beta = module.bias + + @register_loader("T5Config") class T5Loader(ModelLoader): @property diff --git a/python/ctranslate2/models/__init__.py b/python/ctranslate2/models/__init__.py index aba612a5c..35a3dca37 100644 --- a/python/ctranslate2/models/__init__.py +++ b/python/ctranslate2/models/__init__.py @@ -5,6 +5,7 @@ try: from ctranslate2._ext import ( Wav2Vec2, + Wav2Vec2Bert, Whisper, WhisperGenerationResult, WhisperGenerationResultAsync, diff --git a/python/ctranslate2/specs/__init__.py b/python/ctranslate2/specs/__init__.py index 22552f5c9..b4e53fad2 100644 --- a/python/ctranslate2/specs/__init__.py +++ b/python/ctranslate2/specs/__init__.py @@ -14,4 +14,5 @@ TransformerSpec, ) from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec +from ctranslate2.specs.wav2vec2bert_spec import Wav2Vec2BertSpec from ctranslate2.specs.whisper_spec import WhisperSpec diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py index 2180d779b..f49d41121 100644 --- a/python/ctranslate2/specs/attention_spec.py +++ b/python/ctranslate2/specs/attention_spec.py @@ -19,6 +19,7 @@ def __init__( self, self_attention=False, relative_position=False, + relative_asymmetric_position=False, relative_attention_bias=False, rms_norm=False, rotary_dim=None, @@ -47,6 +48,11 @@ def __init__( self.relative_attention_bias = None self.relative_attention_max_distance = None + if relative_asymmetric_position: + self.relative_asymmetric_position_keys = None + self.relative_left_max_position = None + self.relative_right_max_position = None + if original_max_position_embeddings != 0: 
self.original_max_position_embeddings = np.dtype("int32").type( original_max_position_embeddings diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index b1162839c..598a452d6 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -13,6 +13,7 @@ class Activation(enum.IntEnum): GELU = 3 GELUSigmoid = 4 Tanh = 5 + Sigmoid = 6 # This enum should match the C++ equivalent in include/ctranslate2/layers/common.h. diff --git a/python/ctranslate2/specs/wav2vec2bert_spec.py b/python/ctranslate2/specs/wav2vec2bert_spec.py new file mode 100644 index 000000000..5069e06e6 --- /dev/null +++ b/python/ctranslate2/specs/wav2vec2bert_spec.py @@ -0,0 +1,86 @@ +from ctranslate2.specs import attention_spec, common_spec, model_spec + + +class Wav2Vec2BertConfig(model_spec.ModelConfig): + """Configuration for the Wav2Vec2Bert model.""" + + def __init__(self): + return + + +class Wav2Vec2BertSpec(model_spec.LanguageModelSpec): + def __init__(self, num_hidden_layers, num_adapter_layers): + super().__init__() + self.encoder = Wav2Vec2BertEncoderSpec( + num_adapter_layers, + num_hidden_layers, + ) + + @property + def name(self): + return "Wav2Vec2BertSpec" + + @property + def revision(self): + return 1 + + def get_default_config(self): + return Wav2Vec2BertConfig() + + def get_vocabulary_size(self): + return self.encoder.lm_head.weight.shape[0] + + +class Wav2Vec2BertFeedForwardSpec(model_spec.LayerSpec): + def __init__(self, glu=False, rms_norm=False): + self.linear_0 = common_spec.LinearSpec() + self.linear_1 = common_spec.LinearSpec() + if glu: + self.linear_0_noact = common_spec.LinearSpec() + + +class EncoderSpec(model_spec.LayerSpec): + def __init__(self): + self.enc_ffn1_layer_norm = common_spec.LayerNormSpec() + self.enc_ffn1 = Wav2Vec2BertFeedForwardSpec() + self.enc_attn_layer_norm = common_spec.LayerNormSpec() + self.enc_attn = attention_spec.MultiHeadAttentionSpec( + self_attention=True, + relative_asymmetric_position=True, + ) + del self.enc_attn.layer_norm + self.enc_conv_layer_norm = common_spec.LayerNormSpec() + self.enc_conv_pointwise_conv1 = common_spec.Conv1DSpec() + del self.enc_conv_pointwise_conv1.bias + self.enc_conv_depthwise_conv = common_spec.Conv1DSpec() + del self.enc_conv_depthwise_conv.bias + self.enc_conv_depthwise_layer_norm = common_spec.LayerNormSpec() + self.enc_conv_pointwise_conv2 = common_spec.Conv1DSpec() + del self.enc_conv_pointwise_conv2.bias + self.enc_ffn2_layer_norm = common_spec.LayerNormSpec() + self.enc_ffn2 = Wav2Vec2BertFeedForwardSpec() + self.enc_final_layer_norm = common_spec.LayerNormSpec() + + +class AdapterSpec(model_spec.LayerSpec): + def __init__(self): + self.adpt_residual_layer_norm = common_spec.LayerNormSpec() + self.adpt_residual_conv = common_spec.Conv1DSpec() + self.adpt_attn_layer_norm = common_spec.LayerNormSpec() + self.adpt_attn_conv = common_spec.Conv1DSpec() + self.adpt_attn_layer = attention_spec.MultiHeadAttentionSpec( + self_attention=True, + relative_asymmetric_position=False, + ) + del self.adpt_attn_layer.layer_norm + self.adpt_ffn_layer_norm = common_spec.LayerNormSpec() + self.adpt_ffn = Wav2Vec2BertFeedForwardSpec() + + +class Wav2Vec2BertEncoderSpec(model_spec.LayerSpec): + def __init__(self, num_hidden_layers, num_adapter_layers): + self.fp_layer_norm = common_spec.LayerNormSpec() + self.fp_projection = common_spec.LinearSpec() + self.encoder_layers = [EncoderSpec() for _ in range(num_hidden_layers)] + self.adapter_layers = [AdapterSpec() for _ in 
range(num_adapter_layers)] + self.lm_head = common_spec.LinearSpec() diff --git a/python/tests/requirements.txt b/python/tests/requirements.txt index a52254f01..c5f4812a7 100644 --- a/python/tests/requirements.txt +++ b/python/tests/requirements.txt @@ -1,4 +1,4 @@ -transformers==4.35.*;platform_system=='Linux' +transformers==4.41.*;platform_system=='Linux' fairseq==0.12.2;platform_system=='Linux' or platform_system=='Darwin' OpenNMT-py==2.2.*;platform_system=='Linux' or platform_system=='Darwin' OpenNMT-tf==2.30.* diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py index 3c35445fa..1fed8196d 100644 --- a/python/tests/test_transformers.py +++ b/python/tests/test_transformers.py @@ -1023,3 +1023,83 @@ def test_transformers_wav2vec2( transcription = transcription[0].replace(processor.tokenizer.unk_token, "") assert transcription == expected_transcription[0] + + +class TestWav2Vec2Bert: + @classmethod + def teardown_class(cls): + clear_transformers_cache_in_ci() + + @test_utils.only_on_linux + @test_utils.on_available_devices + @pytest.mark.parametrize( + "model_name,expected_transcription", + [ + ( + "hf-audio/wav2vec2-bert-CV16-en", + [ + "mr quilter is the apostle of the middle classes and" + " we are glad to welcome his gospel" + ], + ), + ], + ) + def test_transformers_wav2vec2bert( + self, + tmp_dir, + device, + model_name, + expected_transcription, + ): + import torch + import transformers + + converter = ctranslate2.converters.TransformersConverter( + model_name, load_as_float16="int8" + ) + output_dir = str(tmp_dir.join("ctranslate2_model")) + output_dir = converter.convert(output_dir) + + w2v2_processor = transformers.Wav2Vec2BertProcessor.from_pretrained(model_name) + w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor") + processor = transformers.AutoProcessor.from_pretrained( + output_dir + "/wav2vec2_processor" + ) + + device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu" + cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0)) + model = ctranslate2.models.Wav2Vec2Bert( + output_dir, + device=device, + device_index=[0], + compute_type="int8", + intra_threads=cpu_threads, + inter_threads=1, + ) + + speech_array = np.load( + os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy") + ) + input_values = processor( + [speech_array], + padding=True, + return_tensors="pt", + sampling_rate=16000, + ).input_features + + hidden_states = np.ascontiguousarray(input_values) + hidden_states = ctranslate2.StorageView.from_array(hidden_states) + to_cpu = model.device == "cuda" and len(model.device_index) > 1 + output = model.encode(hidden_states, to_cpu=to_cpu) + if model.device == "cuda": + logits = torch.as_tensor(output, device=model.device)[0] + else: + logits = torch.as_tensor( + np.array(output), dtype=torch.float32, device=model.device + )[0] + + predicted_ids = torch.argmax(logits, dim=-1) + transcription = processor.decode(predicted_ids, output_word_offsets=True) + transcription = transcription[0].replace(processor.tokenizer.unk_token, "") + + assert transcription == expected_transcription[0] diff --git a/src/cpu/kernels.cc b/src/cpu/kernels.cc index 2371704ec..c1f48553d 100644 --- a/src/cpu/kernels.cc +++ b/src/cpu/kernels.cc @@ -184,6 +184,13 @@ namespace ctranslate2 { } }; + struct sigmoid_func { + vec_type operator()(vec_type v) const { + using VecType = Vec; + return VecType::div(VecType::load(1.f), VecType::add(VecType::load(1.f), VecType::exp(VecType::neg(v)))); + } + }; + struct swish_func { vec_type 
operator()(vec_type v) const { using VecType = Vec; @@ -244,6 +251,11 @@ namespace ctranslate2 { vectorized_unary_transform(x, y, size, gelu_sigmoid_func()); } + template<> + void sigmoid(const float* x, float* y, dim_t size) { + vectorized_unary_transform(x, y, size, sigmoid_func()); + } + template<> void swish(const float* x, float* y, dim_t size) { vectorized_unary_transform(x, y, size, swish_func()); @@ -696,6 +708,9 @@ namespace ctranslate2 { case ops::ActivationType::GELUSigmoid: dequantize_gemm_output_row(c, a_scale, b_scale, bias, m, y, gelu_sigmoid_func()); break; + case ops::ActivationType::Sigmoid: + dequantize_gemm_output_row(c, a_scale, b_scale, bias, m, y, sigmoid_func()); + break; case ops::ActivationType::Swish: dequantize_gemm_output_row(c, a_scale, b_scale, bias, m, y, swish_func()); break; diff --git a/src/cpu/kernels.h b/src/cpu/kernels.h index 71d52cc67..16296fc36 100644 --- a/src/cpu/kernels.h +++ b/src/cpu/kernels.h @@ -27,6 +27,8 @@ namespace ctranslate2 { template void gelu_sigmoid(const float* x, float* y, dim_t size); template + void sigmoid(const float* x, float* y, dim_t size); + template void swish(const float* x, float* y, dim_t size); template diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc index 0c6377bbb..5e0fd2999 100644 --- a/src/cpu/primitives.cc +++ b/src/cpu/primitives.cc @@ -313,6 +313,15 @@ namespace ctranslate2 { }); } + template<> + template<> + void primitives::sigmoid(const float* x, float* y, dim_t size) { + cpu::parallel_for(0, size, cpu::GRAIN_SIZE / 10, + [x, y](dim_t begin, dim_t end) { + CPU_ISA_DISPATCH((cpu::sigmoid(x + begin, y + begin, end - begin))); + }); + } + template<> template<> void primitives::swish(const float* x, float* y, dim_t size) { diff --git a/src/cuda/helpers.h b/src/cuda/helpers.h index a34d5d892..391fae73f 100644 --- a/src/cuda/helpers.h +++ b/src/cuda/helpers.h @@ -255,6 +255,14 @@ namespace ctranslate2 { } }; + template + struct sigmoid_func { + // Implicitly promote half to float in this function. + __device__ float operator()(float x) const { + return 1.f / (1.f + expf(-x)); + } + }; + template struct swish_func { // Implicitly promote half to float in this function. 
diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu index 149e10dbb..9915bb12c 100644 --- a/src/cuda/primitives.cu +++ b/src/cuda/primitives.cu @@ -218,6 +218,12 @@ namespace ctranslate2 { cuda::unary_transform(x, y, size, cuda::gelu_sigmoid_func>()); } + template<> + template + void primitives::sigmoid(const T* x, T* y, dim_t size) { + cuda::unary_transform(x, y, size, cuda::sigmoid_func>()); + } + template<> template void primitives::swish(const T* x, T* y, dim_t size) { @@ -789,6 +795,7 @@ namespace ctranslate2 { template void primitives::gelu(const T*, T*, dim_t); \ template void primitives::gelu_tanh(const T*, T*, dim_t); \ template void primitives::gelu_sigmoid(const T*, T*, dim_t); \ + template void primitives::sigmoid(const T*, T*, dim_t); \ template void primitives::swish(const T*, T*, dim_t); \ template float primitives::logsumexp(const T*, dim_t); \ template void primitives::sin(const T*, T*, dim_t); \ diff --git a/src/layers/attention.cc b/src/layers/attention.cc index a206bcd05..6ad344410 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -31,6 +31,25 @@ namespace ctranslate2 { return positions; } + StorageView make_asymmetric_relative_positions(dim_t queries_length, + dim_t keys_length, + dim_t left_max_position, + dim_t right_max_position) { + StorageView positions({queries_length, keys_length}, DataType::INT32); + auto* positions_data = positions.data(); + + const dim_t offset = keys_length - queries_length; + + for (dim_t i = 0; i < queries_length; ++i) { + auto* row = positions_data + i * keys_length; + for (dim_t j = 0; j < keys_length; ++j) { + row[j] = std::max(std::min(j - i, right_max_position), -left_max_position) + left_max_position; + } + } + + return positions; + } + static StorageView get_relative_position_bucket(bool bidirectional, dim_t query_length, dim_t key_length, @@ -163,8 +182,11 @@ namespace ctranslate2 { const StorageView& values, const StorageView* values_lengths, const StorageView* relative_position_keys, + const StorageView* relative_asymmetric_position_keys, const StorageView* relative_position_values, const StorageView* relative_attention_bias, + dim_t relative_left_max_position, + dim_t relative_right_max_position, dim_t maximum_relative_position, StorageView& output, StorageView* attention = nullptr, @@ -178,13 +200,19 @@ namespace ctranslate2 { PROFILE("dot_product_attention"); std::unique_ptr relative_positions; - if (relative_position_keys || relative_position_values) { + if (relative_position_keys || relative_position_values || relative_asymmetric_position_keys) { const dim_t query_length = queries.dim(2); const dim_t key_length = keys.dim(2); - relative_positions = std::make_unique( - make_relative_positions(query_length, - key_length, - maximum_relative_position).to(queries.device())); + if (relative_asymmetric_position_keys) + relative_positions = std::make_unique( + make_asymmetric_relative_positions(query_length, + key_length, + relative_left_max_position, + relative_right_max_position).to(queries.device())); + else relative_positions = std::make_unique( + make_relative_positions(query_length, + key_length, + maximum_relative_position).to(queries.device())); } const ops::MatMul keys_matmul(/*trans_a=*/false, /*trans_b=*/true, queries_scale); @@ -196,6 +224,12 @@ namespace ctranslate2 { keys_matmul, output); + if (relative_asymmetric_position_keys) + add_relative_representations(queries, + *relative_positions, + *relative_asymmetric_position_keys, + keys_matmul, + output); if (relative_attention_bias) { 
StorageView local_position_bias(output.dtype(), output.device()); @@ -269,6 +303,7 @@ namespace ctranslate2 { : AttentionLayer(model, scope, num_heads, self_attention, pre_norm, is_decoder, alibi, false) , _relative_attention_bias(model.get_variable_if_exists(scope + "/relative_attention_bias")) , _relative_position_keys(model.get_variable_if_exists(scope + "/relative_position_keys")) + , _relative_asymmetric_position_keys(model.get_variable_if_exists(scope + "/relative_asymmetric_position_keys")) , _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values")) , _merge_time_and_head_dims(_multi_query && !_relative_attention_bias @@ -278,6 +313,12 @@ namespace ctranslate2 { { if (_relative_position_keys) _maximum_relative_position = (_relative_position_keys->dim(0) - 1) / 2; + else if (_relative_asymmetric_position_keys) { + _relative_left_max_position = model.get_attribute( + scope + "/relative_left_max_position"); + _relative_right_max_position = model.get_attribute( + scope + "/relative_right_max_position"); + } else if (_relative_attention_bias) _maximum_relative_position = model.get_attribute( scope + "/relative_attention_max_distance"); @@ -432,8 +473,11 @@ namespace ctranslate2 { values_proj, values_lengths, _relative_position_keys, + _relative_asymmetric_position_keys, _relative_position_values, _relative_attention_bias, + _relative_left_max_position, + _relative_right_max_position, _maximum_relative_position, context, attention, diff --git a/src/layers/wav2vec2bert.cc b/src/layers/wav2vec2bert.cc new file mode 100644 index 000000000..7ac3620e5 --- /dev/null +++ b/src/layers/wav2vec2bert.cc @@ -0,0 +1,210 @@ +#include "ctranslate2/layers/wav2vec2bert.h" + +namespace ctranslate2 { + namespace layers { + + EncoderLayer::EncoderLayer(const models::Model& model, + const std::string& scope, + const bool pre_norm, + const ops::ActivationType activation_type, + const bool use_flash_attention) + : _ffn1_layer_norm(model, scope + "/enc_ffn1_layer_norm") + , _ff1(model, scope + "/enc_ffn1", pre_norm, activation_type) + , _self_attn_layer_norm(model, scope + "/enc_attn_layer_norm") + , _num_heads(model.get_attribute_with_default(scope + "/num_heads", 16)) + , _self_attention(!use_flash_attention ? 
std::unique_ptr(new MultiHeadAttention(model, + scope + "/enc_attn", + _num_heads, + /*self_attention=*/true, + pre_norm)) : std::unique_ptr(new FlashMultiHeadAttention(model, + scope + "/enc_attn", + _num_heads, + /*self_attention=*/true, + pre_norm))) + , _transpose({0, 2, 1}) + , _layer_norm(model, scope + "/enc_conv_layer_norm") + , _pconv1(model, scope + "/enc_conv_pointwise_conv1", /*stride=*/1, /*padding=*/0) + , _dconv(model, scope + "/enc_conv_depthwise_conv", /*stride=*/1, /*padding=*/0, /*dilation*/1, /*groups*/1024) + , _dlayer_norm(model, scope +"/enc_conv_depthwise_layer_norm") + , _pconv2(model, scope + "/enc_conv_pointwise_conv2", /*stride=*/1, /*padding=*/0) + , _ffn2_layer_norm(model, scope + "/enc_ffn2_layer_norm") + , _ff2(model, scope + "/enc_ffn2", pre_norm, activation_type) + , _final_layer_norm(model, scope + "/enc_final_layer_norm") { + } + + void EncoderLayer::operator()(const StorageView& input, StorageView& output) const{ + PROFILE("EncoderLayer"); + + StorageView buffer1(input.dtype(), input.device()); + StorageView buffer2(input.dtype(), input.device()); + StorageView buffer3(input.dtype(), input.device()); + StorageView residual(input.dtype(), input.device()); + StorageView m(static_cast(0.5)); + + _ffn1_layer_norm(input, buffer1); + _ff1(buffer1, buffer2); + ops::Mul()(buffer2, m, buffer1); + ops::Add()(buffer1, input, buffer2); + residual.copy_from(buffer2); + + _self_attn_layer_norm(buffer2, buffer1); + (*_self_attention)(buffer1, + buffer1, + nullptr, + buffer2, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + true, + nullptr); + ops::Add()(buffer2, residual, buffer1); + + residual.copy_from(buffer1); + _layer_norm(buffer1, buffer2); + + _transpose(buffer2, buffer1); + + _pconv1(buffer1, buffer2); + std::vector out{&buffer1, &buffer3}; + ops::Split(1, {buffer2.dim(1)/2, buffer2.dim(1)/2})(buffer2, out); + _sigmoid(buffer3, buffer3); + ops::Mul()(buffer1, buffer3, buffer2); + + StorageView buffer_zeros({buffer2.dim(0), buffer2.dim(1), 30}, + buffer2.dtype(), + buffer2.device()); + buffer_zeros.zero(); + ops::Concat(-1)({&buffer_zeros, &buffer2}, buffer1); + _dconv(buffer1, buffer2); + _transpose(buffer2, buffer1); + _dlayer_norm(buffer1, buffer2); + _transpose(buffer2, buffer1); + _swish(buffer1, buffer2); + _pconv2(buffer2, buffer1); + _transpose(buffer1, buffer2); + ops::Add()(buffer2, residual, buffer1); + + residual.copy_from(buffer1); + _ffn2_layer_norm(buffer1, buffer2); + _ff2(buffer2, buffer1); + ops::Mul()(buffer1, m, buffer2); + ops::Add()(buffer2, residual, buffer1); + + _final_layer_norm(buffer1, output); + } + + AdapterLayer::AdapterLayer(const models::Model& model, + const std::string& scope, + const bool pre_norm, + const ops::ActivationType activation_type, + const bool use_flash_attention) + : _residual_layer_norm(model, scope + "/adpt_residual_layer_norm") + , _transpose({0, 2, 1}) + , _residual_conv(model, scope + "/adpt_residual_conv", /*stride=*/2, /*padding=*/1) + , _attn_layer_norm(model, scope + "/adpt_attn_layer_norm") + , _attn_conv(model, scope + "/adpt_attn_conv", /*stride=*/2, /*padding=*/1) + , _num_heads(model.get_attribute_with_default(scope + "/num_heads", 16)) + , _self_attention(!use_flash_attention ? 
std::unique_ptr(new MultiHeadAttention(model, + scope + "/adpt_attn_layer", + _num_heads, + /*self_attention=*/true, + pre_norm)) : std::unique_ptr(new FlashMultiHeadAttention(model, + scope + "/adpt_attn_layer", + _num_heads, + /*self_attention=*/true, + pre_norm))) + , _ffn_layer_norm(model, scope + "/adpt_ffn_layer_norm") + , _ffn(model, scope + "/adpt_ffn", pre_norm, activation_type) { + } + + void AdapterLayer::operator()(const StorageView& input, StorageView& output) const{ + PROFILE("AdapterLayer"); + + StorageView buffer1(input.dtype(), input.device()); + StorageView buffer2(input.dtype(), input.device()); + StorageView buffer3(input.dtype(), input.device()); + StorageView residual(input.dtype(), input.device()); + std::vector out{&buffer2, &buffer3}; + + _residual_layer_norm(input, buffer1); + _transpose(buffer1, buffer2); + _residual_conv(buffer2, buffer1); + ops::Split(1, {buffer1.dim(1)/2, buffer1.dim(1)/2})(buffer1, out); + _sigmoid(buffer3, buffer3); + ops::Mul()(buffer2, buffer3, buffer1); + + _transpose(buffer1, residual); + _attn_layer_norm(input, buffer1); + _transpose(buffer1, buffer2); + _attn_conv(buffer2, buffer1); + ops::Split(1, {buffer1.dim(1)/2, buffer1.dim(1)/2})(buffer1, out); + _sigmoid(buffer3, buffer3); + ops::Mul()(buffer2, buffer3, buffer1); + + _transpose(buffer1, buffer2); + (*_self_attention)(buffer2, + buffer2, + nullptr, + buffer1, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + true, + nullptr); + ops::Add()(buffer1, residual, buffer2); + + residual.copy_from(buffer2); + _ffn_layer_norm(buffer2, buffer1); + _ffn(buffer1, buffer2); + ops::Add()(buffer1, residual, output); + } + + Wav2Vec2BertEncoder::Wav2Vec2BertEncoder(const models::Model& model, const std::string& scope) + : _fp_layer_norm(model, scope + "/fp_layer_norm") + , _fp_projection(model, scope + "/fp_projection", nullptr, true) + , _encoder_layers(build_layers_list(model, + scope + "/encoder_layers", + /*pre_norm=*/true, + ops::ActivationType::Swish, + /*use_flash_attention=*/false)) + , _adapt_layers(build_layers_list(model, + scope + "/adapter_layers", + /*pre_norm=*/true, + ops::ActivationType::ReLU, + /*use_flash_attention=*/false)) + , _lm_head(model, scope + "/lm_head", nullptr, true) { + } + + void Wav2Vec2BertEncoder::operator()(const StorageView& features, StorageView& output) { + PROFILE("Wav2Vec2BertEncoder"); + + // SAD in front-end handles the input length + if (features.rank() != 3) + throw std::invalid_argument("Expected input features to have 3 dimensions, but got " + + std::to_string(features.rank()) + + " dimension(s) instead"); + + StorageView buffer1(features.dtype(), features.device()); + StorageView buffer2(features.dtype(), features.device()); + _fp_layer_norm(features, buffer1); + _fp_projection(buffer1, buffer2); + + for (const auto& layer : _encoder_layers) { + (*layer)(buffer2, buffer1); + buffer2 = std::move(buffer1); + } + + for (const auto& layer : _adapt_layers) { + (*layer)(buffer2, buffer1); + buffer2 = std::move(buffer1); + } + + _lm_head(buffer2, output); + } + + } +} diff --git a/src/models/model_factory.cc b/src/models/model_factory.cc index 488e0b8b2..059051f5d 100644 --- a/src/models/model_factory.cc +++ b/src/models/model_factory.cc @@ -4,6 +4,7 @@ #include "ctranslate2/models/whisper.h" #include "ctranslate2/models/wav2vec2.h" +#include "ctranslate2/models/wav2vec2bert.h" #include "ctranslate2/models/transformer.h" namespace ctranslate2 { @@ -23,6 +24,8 @@ namespace ctranslate2 { register_model("WhisperSpec"); register_model("Wav2Vec2Spec"); 
+ + register_model("Wav2Vec2BertSpec"); } std::shared_ptr create_model(const std::string& name) { diff --git a/src/models/wav2vec2bert.cc b/src/models/wav2vec2bert.cc new file mode 100644 index 000000000..b4a16a512 --- /dev/null +++ b/src/models/wav2vec2bert.cc @@ -0,0 +1,118 @@ +#include "ctranslate2/models/wav2vec2bert.h" + +#include + +#include "ctranslate2/decoding.h" + +#include "dispatch.h" +#include "dtw.h" + +#ifdef CT2_WITH_CUDA +# include "cuda/utils.h" +#endif + + +namespace ctranslate2 { + namespace models { + + const Vocabulary& Wav2Vec2BertModel::get_vocabulary() const { + return *_vocabulary; + } + + size_t Wav2Vec2BertModel::current_spec_revision() const { + return 1; + } + + void Wav2Vec2BertModel::initialize(ModelReader& model_reader) { + VocabularyInfo vocab_info; + vocab_info.unk_token = "[UNK]"; + vocab_info.bos_token = ""; + vocab_info.eos_token = ""; + + _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info)); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } + + bool Wav2Vec2BertModel::is_quantizable(const std::string& variable_name) const { + return Model::is_quantizable(variable_name); + } + + bool Wav2Vec2BertModel::is_linear_weight(const std::string& variable_name) const { + return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos; + } + + std::unique_ptr Wav2Vec2BertModel::clone() const { + return std::make_unique(*this); + } + + + std::unique_ptr Wav2Vec2BertReplica::create_from_model(const Model& model) { + if (!dynamic_cast(&model)) + throw std::invalid_argument("The model is not a Wav2Vec2Bert model"); + + const auto scoped_device_setter = model.get_scoped_device_setter(); + const auto model_ptr = model.shared_from_this(); + const auto concrete_model = std::static_pointer_cast(model_ptr); + return std::make_unique(concrete_model); + } + + Wav2Vec2BertReplica::Wav2Vec2BertReplica(const std::shared_ptr& model) + : ModelReplica(model) + , _model(model) + , _encoder(std::make_unique(*model, "encoder")) + { + } + + + StorageView Wav2Vec2BertReplica::encode(StorageView features, const bool to_cpu) { + PROFILE("Wav2Vec2BertReplica::encode"); + +#ifdef CT2_WITH_CUDA + const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false); +#endif + + const auto scoped_device_setter = _model->get_scoped_device_setter(); + const Device device = _model->device(); + const DataType dtype = _encoder->output_type(); + features.move_to(device, dtype); + + StorageView encoder_output(dtype, device); + (*_encoder)(features, encoder_output); + + if (to_cpu) { + if (device != Device::CPU) + encoder_output = encoder_output.to(Device::CPU); + + return encoder_output; + } + + // Ensure all operations are finished before returning the output. 
+ synchronize_stream(device); + + return encoder_output; + } + + StorageView Wav2Vec2BertReplica::maybe_encode(StorageView features) { + const Device device = _model->device(); + const DataType dtype = _encoder->output_type(); + + features.move_to(device, dtype); + + if (_encoder->is_encoded(features)) + return features; + + StorageView encoder_output(dtype, device); + (*_encoder)(features, encoder_output); + return encoder_output; + } + + std::future Wav2Vec2Bert::encode(const StorageView& features, const bool to_cpu) { + return post( + [features = features.sync_copy(), to_cpu](Wav2Vec2BertReplica& replica) mutable { + return replica.encode(std::move(features), to_cpu); + }); + } + + } +} diff --git a/src/ops/activation.cc b/src/ops/activation.cc index df2bfe6c1..5de89ffb3 100644 --- a/src/ops/activation.cc +++ b/src/ops/activation.cc @@ -2,6 +2,7 @@ #include "ctranslate2/ops/gelu.h" #include "ctranslate2/ops/relu.h" +#include "ctranslate2/ops/sigmoid.h" #include "ctranslate2/ops/swish.h" #include "ctranslate2/ops/tanh.h" @@ -26,6 +27,10 @@ namespace ctranslate2 { static const GELU gelu(GELU::Approximation::Sigmoid); return gelu; } + case ActivationType::Sigmoid: { + static const Sigmoid sigmoid; + return sigmoid; + } case ActivationType::Swish: { static const Swish swish; return swish; diff --git a/src/ops/bias_add_gpu.cu b/src/ops/bias_add_gpu.cu index 8f53bcf64..951a7671e 100644 --- a/src/ops/bias_add_gpu.cu +++ b/src/ops/bias_add_gpu.cu @@ -61,6 +61,11 @@ namespace ctranslate2 { x, b, y, depth, cuda::plus(), cuda::gelu_sigmoid_func()); break; + case ActivationType::Sigmoid: + bias_add_kernel<<>>( + x, b, y, depth, cuda::plus(), cuda::sigmoid_func()); + break; + case ActivationType::Swish: bias_add_kernel<<>>( x, b, y, depth, cuda::plus(), cuda::swish_func()); diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu index 241b3acdb..d14ae7efb 100644 --- a/src/ops/dequantize_gpu.cu +++ b/src/ops/dequantize_gpu.cu @@ -98,6 +98,12 @@ namespace ctranslate2 { break; } + case ActivationType::Sigmoid: { + dequantize_gemm_output_kernel<<>>( + c, a_scales, b_scales, transpose_a, transpose_b, bias, cuda::sigmoid_func(), y, depth); + break; + } + case ActivationType::Swish: { dequantize_gemm_output_kernel<<>>( c, a_scales, b_scales, transpose_a, transpose_b, bias, cuda::swish_func(), y, depth); diff --git a/src/ops/sigmoid.cc b/src/ops/sigmoid.cc new file mode 100644 index 000000000..3fc006034 --- /dev/null +++ b/src/ops/sigmoid.cc @@ -0,0 +1,14 @@ +#include "ctranslate2/ops/sigmoid.h" + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + void Sigmoid::operator()(const StorageView& x, StorageView& y) const { + PROFILE("Sigmoid"); + DEVICE_AND_FLOAT_DISPATCH("Sigmoid", x.device(), x.dtype(), (compute(x, y))); + } + + } +} diff --git a/tests/layers_test.cc b/tests/layers_test.cc index 3a8e40958..cbbaa2d72 100644 --- a/tests/layers_test.cc +++ b/tests/layers_test.cc @@ -21,6 +21,16 @@ TEST(LayerTest, MakeRelativePositions2D) { expect_storage_eq(positions, expected); } +TEST(LayerTest, MakeAsymmetricRelativePositions2D) { + const StorageView positions = layers::make_asymmetric_relative_positions(4, 4, 3, 2); + const StorageView expected({4, 4}, std::vector{ + 3, 4, 5, 5, + 2, 3, 4, 5, + 1, 2, 3, 4, + 0, 1, 2, 3}); + expect_storage_eq(positions, expected); +} + TEST_P(LayerDeviceFPTest, Alibi) { const Device device = GetParam().device; const DataType dtype = GetParam().dtype; diff --git a/tests/ops_test.cc b/tests/ops_test.cc index 7d7b376fa..c9369fa67 100644 --- 
a/tests/ops_test.cc +++ b/tests/ops_test.cc @@ -883,6 +883,17 @@ TEST_P(OpDeviceFPTest, GELUSigmoid) { expect_storage_eq(output.to_float32(), expected, error); } +TEST_P(OpDeviceFPTest, Sigmoid) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype; + const float error = GetParam().error; + StorageView input({2}, std::vector<float>{0.2, -1.3}, device); + StorageView expected({2}, std::vector<float>{0.54983395, 0.21416503}, device); + StorageView output(dtype, device); + ops::Sigmoid()(input.to(dtype), output); + expect_storage_eq(output.to_float32(), expected, error); +} + TEST_P(OpDeviceFPTest, Swish) { + const Device device = GetParam().device; + const DataType dtype = GetParam().dtype;
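
The new make_asymmetric_relative_positions helper builds the index matrix used to look up the asymmetric distance embeddings: the relative distance j - i is clamped to [-left_max_position, right_max_position] and then shifted by left_max_position so it can index the embedding table. Below is a minimal NumPy sketch of that rule (not the C++ implementation itself); it reproduces the expected tensor of the new MakeAsymmetricRelativePositions2D test.

    import numpy as np

    def make_asymmetric_relative_positions(q_len, k_len, left_max, right_max):
        # Relative distance j - i, clamped to [-left_max, right_max],
        # then shifted by left_max so it indexes the distance embedding table.
        i = np.arange(q_len)[:, None]
        j = np.arange(k_len)[None, :]
        return np.clip(j - i, -left_max, right_max) + left_max

    # Same arguments as the MakeAsymmetricRelativePositions2D test (4, 4, 3, 2):
    # [[3 4 5 5]
    #  [2 3 4 5]
    #  [1 2 3 4]
    #  [0 1 2 3]]
    print(make_asymmetric_relative_positions(4, 4, 3, 2))

End-to-end usage follows the new test_transformers_wav2vec2bert test. The sketch below assumes transformers 4.41+ (as required by the updated test requirements) and the hf-audio/wav2vec2-bert-CV16-en checkpoint used in the test; the output directory name and the silent placeholder waveform are illustrative only and should be replaced with real values.

    import numpy as np
    import torch
    import transformers

    import ctranslate2

    model_name = "hf-audio/wav2vec2-bert-CV16-en"

    # Convert the Hugging Face checkpoint to a CTranslate2 model directory.
    converter = ctranslate2.converters.TransformersConverter(model_name)
    output_dir = converter.convert("wav2vec2bert_ct2")

    processor = transformers.Wav2Vec2BertProcessor.from_pretrained(model_name)
    model = ctranslate2.models.Wav2Vec2Bert(output_dir, device="cpu", compute_type="int8")

    # 16 kHz mono float32 waveform; replace the silent placeholder with real audio.
    speech_array = np.zeros(16000, dtype=np.float32)
    features = processor(
        [speech_array], padding=True, return_tensors="pt", sampling_rate=16000
    ).input_features

    # encode() runs the full encoder (including the CTC lm_head) and returns the logits.
    output = model.encode(ctranslate2.StorageView.from_array(np.ascontiguousarray(features)))
    logits = torch.as_tensor(np.array(output), dtype=torch.float32)[0]
    predicted_ids = torch.argmax(logits, dim=-1)
    print(processor.decode(predicted_ids))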