diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c472436e..1089106cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,7 +134,7 @@ set(SOURCES src/ops/bias_add.cc src/ops/bias_add_cpu.cc src/ops/concat.cc - src/ops/concat_split_cpu.cc + src/ops/concat_split_slide_cpu.cc src/ops/conv1d.cc src/ops/conv1d_cpu.cc src/ops/cos.cc @@ -168,6 +168,7 @@ set(SOURCES src/ops/softmax.cc src/ops/softmax_cpu.cc src/ops/split.cc + src/ops/slide.cc src/ops/sub.cc src/ops/swish.cc src/ops/tanh.cc @@ -506,7 +507,7 @@ if (WITH_CUDA) src/cuda/utils.cc src/ops/alibi_add_gpu.cu src/ops/bias_add_gpu.cu - src/ops/concat_split_gpu.cu + src/ops/concat_split_slide_gpu.cu src/ops/conv1d_gpu.cu src/ops/dequantize_gpu.cu src/ops/gather_gpu.cu diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h index 15db7dda8..b342f4faa 100644 --- a/include/ctranslate2/layers/attention.h +++ b/include/ctranslate2/layers/attention.h @@ -35,7 +35,8 @@ namespace ctranslate2 { const Padder* queries_padder = nullptr, const Padder* values_padder = nullptr, bool return_normalized_attention = true, - StorageView* position_bias = nullptr) const; + StorageView* position_bias = nullptr, + dim_t offset = 0) const; bool has_positional_embeddings() const { return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi; diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h index 511d8e661..61b9fae47 100644 --- a/include/ctranslate2/layers/transformer.h +++ b/include/ctranslate2/layers/transformer.h @@ -91,7 +91,8 @@ namespace ctranslate2 { const Padder* input_padder = nullptr, const Padder* memory_padder = nullptr, bool return_normalized_attention = true, - StorageView* position_bias = nullptr) const; + StorageView* position_bias = nullptr, + dim_t offset = 0) const; DataType output_type() const override { return _ff.output_type(); @@ -209,6 +210,7 @@ namespace ctranslate2 { std::vector> _alignment_heads; bool _average_alignment_heads; Dense _proj; + const dim_t _sliding_window; }; } diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index 604c43557..051c81acc 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -36,3 +36,4 @@ #include "median_filter.h" #include "rotary.h" #include "alibi_add.h" +#include "slide.h" diff --git a/include/ctranslate2/ops/slide.h b/include/ctranslate2/ops/slide.h new file mode 100644 index 000000000..176f3b0ec --- /dev/null +++ b/include/ctranslate2/ops/slide.h @@ -0,0 +1,26 @@ +#pragma once + +#include "op.h" + +namespace ctranslate2 { + namespace ops { + + class Slide : public Op { + public: + Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy = false); + + void operator()(const StorageView& input, StorageView& output) const; + private: + dim_t _axis; + dim_t _index; + dim_t _size; + bool _no_copy; + + void check_arguments() const; + + template + void compute(const StorageView& input, StorageView& output, const dim_t& index) const; + }; + + } +} diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index e9dbde497..7208be8a9 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -171,6 +171,8 @@ def __init__( self.alibi = alibi self.alibi_use_positive_positions = alibi_use_positive_positions self.scale_alibi = scale_alibi + if sliding_window is not None: + self.sliding_window = np.dtype("int32").type(sliding_window) if ( not relative_position and not relative_attention_bias @@ -225,6 +227,7 @@ def __init__( relative_attention_bias=relative_attention_bias, rms_norm=rms_norm, num_heads_kv=num_heads_kv, + sliding_window=sliding_window, ) self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm) diff --git a/src/layers/attention.cc b/src/layers/attention.cc index 6480cb10d..fbb138def 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -430,7 +430,8 @@ namespace ctranslate2 { const Padder* queries_padder, const Padder* values_padder, bool return_normalized_attention, - StorageView* position_bias) const { + StorageView* position_bias, + dim_t offset) const { PROFILE("MultiHeadAttention"); const Device device = queries.device(); const DataType dtype = queries.dtype(); @@ -449,6 +450,8 @@ namespace ctranslate2 { dim_t beam_size = 1; + bool prefilling = (_sliding_window > 0 && values_lengths); + if (!_self_attention) { queries_proj = std::move(fused_proj); @@ -507,10 +510,6 @@ namespace ctranslate2 { } if (_rotary_embeddings) { - const dim_t offset = (cached_keys && !cached_keys->empty() - ? cached_keys->dim(_cache_time_dim) - : 0); - if (_merge_time_and_head_dims) { queries_proj.reshape({queries_proj.dim(0), -1, _d_model}); split_heads(queries_proj, _num_heads); @@ -536,6 +535,15 @@ namespace ctranslate2 { concat_op({&tmp, &keys_proj}, *cached_keys); tmp = std::move(*cached_values); concat_op({&tmp, &values_proj}, *cached_values); + + if (!prefilling && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) { + // only for generation + const ops::Slide slide_op(2, 1, cached_keys->shape()[2] - 1); + slide_op(*cached_keys, tmp); + *cached_keys = std::move(tmp); + slide_op(*cached_values, tmp); + *cached_values = std::move(tmp); + } } } } @@ -564,6 +572,16 @@ namespace ctranslate2 { _alibi, position_bias); + if (prefilling && cached_keys->shape()[2] > _sliding_window) { + // set only last sliding_window tokens to cached_keys and cached_values after computing attention + const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window); + StorageView tmp(dtype, device); + slide_op(*cached_keys, tmp); + *cached_keys = std::move(tmp); + slide_op(*cached_values, tmp); + *cached_values = std::move(tmp); + } + if (_merge_time_and_head_dims) { context.reshape(queries.shape()); if (queries_padder) diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 0aad5a33b..79959423c 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -121,7 +121,8 @@ namespace ctranslate2 { const Padder* input_padder, const Padder* memory_padder, bool return_normalized_attention, - StorageView* position_bias) const { + StorageView* position_bias, + dim_t offset) const { PROFILE("TransformerDecoderLayer"); const DataType dtype = input.dtype(); @@ -149,7 +150,8 @@ namespace ctranslate2 { input_padder, input_padder, true, - position_bias); + position_bias, + offset); if (_post_attention_layer_norm) (*_post_attention_layer_norm)(input, hidden); @@ -172,7 +174,8 @@ namespace ctranslate2 { input_padder, input_padder, true, - position_bias); + position_bias, + offset); StorageView context(dtype, device); if (_encoder_attention) { @@ -330,7 +333,8 @@ namespace ctranslate2 { ? nullptr : build_position_encoder(model, scope + "/position_encodings", _embeddings)) , _with_encoder_attention(_layers.front()->has_cross_attention()) - , _proj(model, scope + "/projection") { + , _proj(model, scope + "/projection") + , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0)) { dim_t alignment_layer = ( model.get_attribute_with_default(scope + "/alignment_layer", -1)); @@ -467,7 +471,13 @@ namespace ctranslate2 { (*_layernorm_embedding)(layer_in, layer_in); const dim_t batch_size = layer_in.dim(0); - const dim_t max_time = layer_in.dim(1); + dim_t max_time; + + if (_sliding_window > 0 && layer_in.dim(1) > _sliding_window) { + max_time = _sliding_window; + } else + max_time = layer_in.dim(1); + const bool allow_padding_removal = Padder::allow_padding_removal(_device, _compute_type); std::unique_ptr input_padder; @@ -479,14 +489,14 @@ namespace ctranslate2 { lengths = input_lengths.get(); } + bool multi_query = _layers.front()->get_self_attention().multi_query(); + if (lengths) { if (allow_padding_removal) { input_padder = std::make_unique(*lengths, max_time); input_padder->remove_padding(layer_in); } - const bool multi_query = _layers.front()->get_self_attention().multi_query(); - StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask( *lengths, _num_heads, @@ -531,47 +541,85 @@ namespace ctranslate2 { StorageView position_bias(dtype, device); - for (size_t l = 0; l < _layers.size(); ++l) { - StorageView* cached_self_attn_keys = nullptr; - StorageView* cached_self_attn_values = nullptr; - StorageView* cached_attn_keys = nullptr; - StorageView* cached_attn_values = nullptr; - - if (step >= 0) { - const std::string l_str = std::to_string(l); - cached_self_attn_keys = &state.at("self_keys_" + l_str); - cached_self_attn_values = &state.at("self_values_" + l_str); - if (_with_encoder_attention) { - cached_attn_keys = &state.at("memory_keys_" + l_str); - cached_attn_values = &state.at("memory_values_" + l_str); - } + std::vector layer_ins; + + while (true) { + dim_t prompt_size = layer_in.dim(1); + if (_sliding_window == 0 || prompt_size <= _sliding_window) { + layer_ins.push_back(std::move(layer_in)); + break; } + if (layer_in.dim(1) > _sliding_window) { + StorageView tmp(dtype, device); + const ops::Split split_op(1, {_sliding_window, prompt_size - _sliding_window}); + split_op(layer_in, tmp, layer_in); + layer_ins.push_back(std::move(tmp)); + } + } - std::unique_ptr heads_to_select = get_layer_alignment_heads(l, batch_size); - std::unique_ptr layer_attention; - if (attention && heads_to_select) - layer_attention = std::make_unique(dtype, device); + for (size_t i = 0; i < layer_ins.size(); ++i) { + auto layer_in_chunk = layer_ins[i]; + for (size_t l = 0; l < _layers.size(); ++l) { + StorageView* cached_self_attn_keys = nullptr; + StorageView* cached_self_attn_values = nullptr; + StorageView* cached_attn_keys = nullptr; + StorageView* cached_attn_values = nullptr; + + if (step >= 0) { + const std::string l_str = std::to_string(l); + cached_self_attn_keys = &state.at("self_keys_" + l_str); + cached_self_attn_values = &state.at("self_values_" + l_str); + if (_with_encoder_attention) { + cached_attn_keys = &state.at("memory_keys_" + l_str); + cached_attn_values = &state.at("memory_values_" + l_str); + } + } - (*_layers[l])(layer_in, - input_lengths_mask.get(), - memory, - memory_lengths_mask.get(), - cached_self_attn_keys, - cached_self_attn_values, - cached_attn_keys, - cached_attn_values, - layer_out, - layer_attention.get(), - input_padder.get(), - memory_padder.get(), - return_normalized_attention(), - &position_bias); - layer_in = std::move(layer_out); + std::unique_ptr heads_to_select = get_layer_alignment_heads(l, batch_size); + std::unique_ptr layer_attention; + if (attention && heads_to_select) + layer_attention = std::make_unique(dtype, device); + + dim_t offset = _sliding_window * i + step; + if (i > 0) { + auto max_tokens = _sliding_window + layer_in_chunk.dim(1); + StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device); + StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask( + tmp_lengths, + _num_heads, + max_tokens, + /*mask_future=*/true, + multi_query); + + const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1)); + // reuse tmp_lengths + slide_lengths_op(lengths_mask, tmp_lengths); + input_lengths_mask = std::make_unique(std::move(tmp_lengths)); + } - if (layer_attention) { - alignment_heads.emplace_back(dtype, device); - ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back()); + (*_layers[l])(layer_in_chunk, + input_lengths_mask.get(), + memory, + memory_lengths_mask.get(), + cached_self_attn_keys, + cached_self_attn_values, + cached_attn_keys, + cached_attn_values, + layer_out, + layer_attention.get(), + input_padder.get(), + memory_padder.get(), + return_normalized_attention(), + &position_bias, + offset); + layer_in_chunk = std::move(layer_out); + + if (layer_attention) { + alignment_heads.emplace_back(dtype, device); + ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back()); + } } + layer_in = std::move(layer_in_chunk); } if (step == 0) { diff --git a/src/ops/concat_split_cpu.cc b/src/ops/concat_split_slide_cpu.cc similarity index 70% rename from src/ops/concat_split_cpu.cc rename to src/ops/concat_split_slide_cpu.cc index bd308912e..505581a74 100644 --- a/src/ops/concat_split_cpu.cc +++ b/src/ops/concat_split_slide_cpu.cc @@ -1,5 +1,6 @@ #include "ctranslate2/ops/concat.h" #include "ctranslate2/ops/split.h" +#include "ctranslate2/ops/slide.h" #include "cpu/parallel.h" #include "type_dispatch.h" @@ -71,13 +72,41 @@ namespace ctranslate2 { } } + template + void Slide::compute(const StorageView& input, StorageView& output, const dim_t& index) const { + const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis; + const dim_t stride_axis = input.stride(axis) == 0 ? 1 : input.stride(axis); + const dim_t step_size = input.dim(axis) * stride_axis; + const T* input_data = input.data(); + + StorageView& x = output; + T* x_data = x.data(); + + const dim_t copy_size = compute_copy_size(x, axis); + if (copy_size == 0) + return; + + const dim_t iter_size = compute_iter_size(x, axis); + + const dim_t grain_size = cpu::get_minimum_batch_copies_per_thread(copy_size); + input_data += index * stride_axis; // Read next with an offset. + cpu::parallel_for(0, iter_size, grain_size, [&](dim_t begin, dim_t end) { + for (dim_t i = begin; i < end; ++i) + primitives::copy(input_data + i * step_size, x_data + i * copy_size, copy_size); + }); + } + #define DECLARE_IMPL(T) \ template void \ Concat::compute(const std::vector& inputs, \ StorageView& output) const; \ template void \ Split::compute(const StorageView& input, \ - std::vector& outputs) const; + std::vector& outputs) const; \ + template void \ + Slide::compute(const StorageView& input, \ + StorageView& output, \ + const dim_t& index) const; DECLARE_ALL_TYPES(DECLARE_IMPL) diff --git a/src/ops/concat_split_gpu.cu b/src/ops/concat_split_slide_gpu.cu similarity index 74% rename from src/ops/concat_split_gpu.cu rename to src/ops/concat_split_slide_gpu.cu index 15a3663c5..0e24f6ebe 100644 --- a/src/ops/concat_split_gpu.cu +++ b/src/ops/concat_split_slide_gpu.cu @@ -1,5 +1,6 @@ #include "ctranslate2/ops/concat.h" #include "ctranslate2/ops/split.h" +#include "ctranslate2/ops/slide.h" #include #include @@ -163,14 +164,60 @@ namespace ctranslate2 { } } + template + void Slide::compute(const StorageView& input, StorageView& output, const dim_t& index) const { + const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis; + const dim_t input_dim = input.dim(axis); + const dim_t inner_size = input.stride(axis) == 0 ? 1 : input.stride(axis); + const dim_t inner_bytes = inner_size * sizeof (T); + const T* input_data = input.data(); + + T* output_data = output.data(); + const dim_t output_size = output.size(); + const dim_t output_bytes = output_size * sizeof (T); + if (axis == 0) { + dim_t offset = index * output.stride(axis); + primitives::copy(input_data + offset, output_data, output_size); + } + else { + const dim_t output_dim = output.dim(axis); + + if (inner_size == 1) { + auto map_ids = thrust::make_transform_iterator( + thrust::counting_iterator(0), + depth_offset_map(index, output_dim, input_dim)); + THRUST_CALL(thrust::gather, map_ids, map_ids + output_size, input_data, output_data); + } else if (inner_bytes % sizeof(uint4) == 0 && output_bytes % sizeof(uint4) == 0) { + auto map_ids = thrust::make_transform_iterator( + thrust::counting_iterator(0), + inner_dim_offset_map(index, + output_dim, + input_dim, + inner_bytes / sizeof(uint4))); + THRUST_CALL(thrust::gather, + map_ids, + map_ids + output_bytes / sizeof(uint4), + reinterpret_cast(input_data), + reinterpret_cast(output_data)); + } else { + auto map_ids = thrust::make_transform_iterator( + thrust::counting_iterator(0), + inner_dim_offset_map(index, output_dim, input_dim, inner_size)); + THRUST_CALL(thrust::gather, map_ids, map_ids + output_size, input_data, output_data); + } + } + } + #define DECLARE_IMPL(T) \ template void \ Concat::compute(const std::vector& inputs, \ StorageView& output) const; \ template void \ Split::compute(const StorageView& input, \ - std::vector& outputs) const; - + std::vector& outputs) const; \ + template void \ + Slide::compute(const StorageView& input, \ + StorageView& output, const dim_t& index) const; DECLARE_ALL_TYPES(DECLARE_IMPL) } diff --git a/src/ops/slide.cc b/src/ops/slide.cc new file mode 100644 index 000000000..31d3cede7 --- /dev/null +++ b/src/ops/slide.cc @@ -0,0 +1,47 @@ +#include "ctranslate2/ops/slide.h" + +#include + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + Slide::Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy) + : _axis(axis) + , _index(index) + , _size(size) + , _no_copy(no_copy) { + check_arguments(); + } + + void Slide::operator()(const StorageView& input, StorageView& output) const { + PROFILE("Slide"); + const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis; + + if (_index < 0 || _index >= input.dim(axis)) + throw std::invalid_argument("Index or Size given is not valid"); + + dim_t offset = input.stride(0) * _index; + auto shape = input.shape(); + shape[axis] = _size; + if (_no_copy) { + TYPE_DISPATCH(input.dtype(), + output.view(const_cast(input.data() + offset), std::move(shape))); + } + else { + output.resize(std::move(shape)); + } + + if (!_no_copy) { + DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute(input, output, _index))); + } + } + + void Slide::check_arguments() const { + if (_no_copy && _axis != 0) + throw std::invalid_argument("no_copy is only defined when splitting across the first dimension"); + } + + } +}