Sliding window + chunking input for mistral model (#1524)
* sliding window + chunking input
---------
Co-authored-by: thucpham <[email protected]>
minhthuc2502 authored Nov 17, 2023
1 parent b64bb08 commit 120746e
Showing 11 changed files with 278 additions and 55 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -134,7 +134,7 @@ set(SOURCES
src/ops/bias_add.cc
src/ops/bias_add_cpu.cc
src/ops/concat.cc
src/ops/concat_split_cpu.cc
src/ops/concat_split_slide_cpu.cc
src/ops/conv1d.cc
src/ops/conv1d_cpu.cc
src/ops/cos.cc
@@ -168,6 +168,7 @@ set(SOURCES
src/ops/softmax.cc
src/ops/softmax_cpu.cc
src/ops/split.cc
src/ops/slide.cc
src/ops/sub.cc
src/ops/swish.cc
src/ops/tanh.cc
@@ -506,7 +507,7 @@ if (WITH_CUDA)
src/cuda/utils.cc
src/ops/alibi_add_gpu.cu
src/ops/bias_add_gpu.cu
src/ops/concat_split_gpu.cu
src/ops/concat_split_slide_gpu.cu
src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/gather_gpu.cu
3 changes: 2 additions & 1 deletion include/ctranslate2/layers/attention.h
@@ -35,7 +35,8 @@ namespace ctranslate2 {
const Padder* queries_padder = nullptr,
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr) const;
StorageView* position_bias = nullptr,
dim_t offset = 0) const;

bool has_positional_embeddings() const {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
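The extra `offset` argument tells the attention layer where the current chunk starts in the full sequence. Further down in this diff, attention.cc stops deriving this value from the cached key length, which no longer equals the absolute position once the cache is trimmed to the sliding window, so the decoder now passes it explicitly. A minimal sketch of how an offset shifts rotary position embeddings (illustrative only; this is not CTranslate2's RotaryEmbeddings implementation):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration: apply rotary embeddings to a [time x dim] buffer
// whose first row sits at absolute position `offset` in the full sequence.
// dim must be even; theta is the usual rotary base (e.g. 10000).
void apply_rotary_with_offset(std::vector<float>& x,
                              std::size_t time,
                              std::size_t dim,
                              std::size_t offset,
                              float theta = 10000.f) {
  for (std::size_t t = 0; t < time; ++t) {
    const float position = static_cast<float>(offset + t);  // absolute, not chunk-relative
    for (std::size_t i = 0; i < dim / 2; ++i) {
      const float inv_freq = std::pow(theta, -2.f * static_cast<float>(i) / static_cast<float>(dim));
      const float angle = position * inv_freq;
      const float c = std::cos(angle), s = std::sin(angle);
      float& x0 = x[t * dim + 2 * i];
      float& x1 = x[t * dim + 2 * i + 1];
      const float r0 = x0 * c - x1 * s;
      const float r1 = x0 * s + x1 * c;
      x0 = r0;
      x1 = r1;
    }
  }
}
```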
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/transformer.h
@@ -91,7 +91,8 @@ namespace ctranslate2 {
const Padder* input_padder = nullptr,
const Padder* memory_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr) const;
StorageView* position_bias = nullptr,
dim_t offset = 0) const;

DataType output_type() const override {
return _ff.output_type();
@@ -209,6 +210,7 @@ namespace ctranslate2 {
std::vector<std::vector<dim_t>> _alignment_heads;
bool _average_alignment_heads;
Dense _proj;
const dim_t _sliding_window;
};

}
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
@@ -36,3 +36,4 @@
#include "median_filter.h"
#include "rotary.h"
#include "alibi_add.h"
#include "slide.h"
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/slide.h
@@ -0,0 +1,26 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

class Slide : public Op {
public:
Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy = false);

void operator()(const StorageView& input, StorageView& output) const;
private:
dim_t _axis;
dim_t _index;
dim_t _size;
bool _no_copy;

void check_arguments() const;

template <Device D, typename T>
void compute(const StorageView& input, StorageView& output, const dim_t& index) const;
};

}
}
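The new Slide op copies `size` consecutive elements starting at `index` along `axis`; in this commit the attention layer uses it to drop the oldest positions from the key/value cache. A minimal CPU sketch of the same semantics for a row-major tensor (illustrative only; the real op also supports a `no_copy` mode and has kernels in concat_split_slide_cpu.cc and concat_split_slide_gpu.cu):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical illustration of Slide(axis, index, size): copy the slice
// [index, index + size) taken along `axis` of a row-major tensor.
std::vector<float> slide(const std::vector<float>& input,
                         const std::vector<std::size_t>& shape,
                         std::size_t axis,
                         std::size_t index,
                         std::size_t size) {
  // outer = product of dims before axis, inner = product of dims after axis.
  std::size_t outer = 1, inner = 1;
  for (std::size_t d = 0; d < axis; ++d)
    outer *= shape[d];
  for (std::size_t d = axis + 1; d < shape.size(); ++d)
    inner *= shape[d];

  std::vector<float> output(outer * size * inner);
  for (std::size_t o = 0; o < outer; ++o)
    for (std::size_t s = 0; s < size; ++s)
      for (std::size_t i = 0; i < inner; ++i)
        output[(o * size + s) * inner + i] =
            input[(o * shape[axis] + index + s) * inner + i];
  return output;
}
```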
3 changes: 3 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -171,6 +171,8 @@ def __init__(
self.alibi = alibi
self.alibi_use_positive_positions = alibi_use_positive_positions
self.scale_alibi = scale_alibi
if sliding_window is not None:
self.sliding_window = np.dtype("int32").type(sliding_window)
if (
not relative_position
and not relative_attention_bias
@@ -225,6 +227,7 @@ def __init__(
relative_attention_bias=relative_attention_bias,
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)

28 changes: 23 additions & 5 deletions src/layers/attention.cc
@@ -430,7 +430,8 @@ namespace ctranslate2 {
const Padder* queries_padder,
const Padder* values_padder,
bool return_normalized_attention,
StorageView* position_bias) const {
StorageView* position_bias,
dim_t offset) const {
PROFILE("MultiHeadAttention");
const Device device = queries.device();
const DataType dtype = queries.dtype();
@@ -449,6 +450,8 @@ namespace ctranslate2 {

dim_t beam_size = 1;

bool prefilling = (_sliding_window > 0 && values_lengths);

if (!_self_attention) {
queries_proj = std::move(fused_proj);

@@ -507,10 +510,6 @@ }
}

if (_rotary_embeddings) {
const dim_t offset = (cached_keys && !cached_keys->empty()
? cached_keys->dim(_cache_time_dim)
: 0);

if (_merge_time_and_head_dims) {
queries_proj.reshape({queries_proj.dim(0), -1, _d_model});
split_heads(queries_proj, _num_heads);
@@ -536,6 +535,15 @@ namespace ctranslate2 {
concat_op({&tmp, &keys_proj}, *cached_keys);
tmp = std::move(*cached_values);
concat_op({&tmp, &values_proj}, *cached_values);

if (!prefilling && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) {
// only for generation
const ops::Slide slide_op(2, 1, cached_keys->shape()[2] - 1);
slide_op(*cached_keys, tmp);
*cached_keys = std::move(tmp);
slide_op(*cached_values, tmp);
*cached_values = std::move(tmp);
}
}
}
}
@@ -564,6 +572,16 @@
_alibi,
position_bias);

if (prefilling && cached_keys->shape()[2] > _sliding_window) {
// set only last sliding_window tokens to cached_keys and cached_values after computing attention
const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window);
StorageView tmp(dtype, device);
slide_op(*cached_keys, tmp);
*cached_keys = std::move(tmp);
slide_op(*cached_values, tmp);
*cached_values = std::move(tmp);
}

if (_merge_time_and_head_dims) {
context.reshape(queries.shape());
if (queries_padder)
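In MultiHeadAttention, the sliding window is enforced in two places. During generation, once the cache length exceeds `_sliding_window`, `Slide(2, 1, length - 1)` drops the oldest position after each appended token. During prefill (`prefilling == true`), attention is first computed over the full cache, and only afterwards are the caches cut back to the last `_sliding_window` positions. The toy bookkeeping below mirrors that policy on a plain deque (names and structure are hypothetical, for illustration only):

```cpp
#include <cstdint>
#include <deque>

// Hypothetical per-layer KV cache over the time dimension; each entry stands
// in for the key/value vectors of one cached position.
struct KVCache {
  std::deque<int64_t> positions;
};

// Generation step: append one position, then drop the oldest one if the cache
// has grown past the window (mirrors Slide(2, 1, length - 1)).
void append_generation_step(KVCache& cache, int64_t position, int64_t sliding_window) {
  cache.positions.push_back(position);
  if (sliding_window > 0 && static_cast<int64_t>(cache.positions.size()) > sliding_window)
    cache.positions.pop_front();
}

// Prefill: append a whole chunk, compute attention over everything cached,
// then keep only the last `sliding_window` positions
// (mirrors Slide(2, length - sliding_window, sliding_window)).
void append_prefill_chunk(KVCache& cache, int64_t first, int64_t count, int64_t sliding_window) {
  for (int64_t p = first; p < first + count; ++p)
    cache.positions.push_back(p);
  // ... attention over all cached positions happens here ...
  while (sliding_window > 0 && static_cast<int64_t>(cache.positions.size()) > sliding_window)
    cache.positions.pop_front();
}
```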
134 changes: 91 additions & 43 deletions src/layers/transformer.cc
@@ -121,7 +121,8 @@ namespace ctranslate2 {
const Padder* input_padder,
const Padder* memory_padder,
bool return_normalized_attention,
StorageView* position_bias) const {
StorageView* position_bias,
dim_t offset) const {
PROFILE("TransformerDecoderLayer");

const DataType dtype = input.dtype();
@@ -149,7 +150,8 @@ namespace ctranslate2 {
input_padder,
input_padder,
true,
position_bias);
position_bias,
offset);

if (_post_attention_layer_norm)
(*_post_attention_layer_norm)(input, hidden);
@@ -172,7 +174,8 @@
input_padder,
input_padder,
true,
position_bias);
position_bias,
offset);

StorageView context(dtype, device);
if (_encoder_attention) {
@@ -330,7 +333,8 @@
? nullptr
: build_position_encoder(model, scope + "/position_encodings", _embeddings))
, _with_encoder_attention(_layers.front()->has_cross_attention())
, _proj(model, scope + "/projection") {
, _proj(model, scope + "/projection")
, _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0)) {

dim_t alignment_layer = (
model.get_attribute_with_default<int32_t>(scope + "/alignment_layer", -1));
@@ -467,7 +471,13 @@
(*_layernorm_embedding)(layer_in, layer_in);

const dim_t batch_size = layer_in.dim(0);
const dim_t max_time = layer_in.dim(1);
dim_t max_time;

if (_sliding_window > 0 && layer_in.dim(1) > _sliding_window) {
max_time = _sliding_window;
} else
max_time = layer_in.dim(1);

const bool allow_padding_removal = Padder::allow_padding_removal(_device, _compute_type);

std::unique_ptr<const Padder> input_padder;
@@ -479,14 +489,14 @@
lengths = input_lengths.get();
}

bool multi_query = _layers.front()->get_self_attention().multi_query();

if (lengths) {
if (allow_padding_removal) {
input_padder = std::make_unique<Padder>(*lengths, max_time);
input_padder->remove_padding(layer_in);
}

const bool multi_query = _layers.front()->get_self_attention().multi_query();

StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
*lengths,
_num_heads,
@@ -531,47 +541,85 @@

StorageView position_bias(dtype, device);

for (size_t l = 0; l < _layers.size(); ++l) {
StorageView* cached_self_attn_keys = nullptr;
StorageView* cached_self_attn_values = nullptr;
StorageView* cached_attn_keys = nullptr;
StorageView* cached_attn_values = nullptr;

if (step >= 0) {
const std::string l_str = std::to_string(l);
cached_self_attn_keys = &state.at("self_keys_" + l_str);
cached_self_attn_values = &state.at("self_values_" + l_str);
if (_with_encoder_attention) {
cached_attn_keys = &state.at("memory_keys_" + l_str);
cached_attn_values = &state.at("memory_values_" + l_str);
}
std::vector<StorageView> layer_ins;

while (true) {
dim_t prompt_size = layer_in.dim(1);
if (_sliding_window == 0 || prompt_size <= _sliding_window) {
layer_ins.push_back(std::move(layer_in));
break;
}
if (layer_in.dim(1) > _sliding_window) {
StorageView tmp(dtype, device);
const ops::Split split_op(1, {_sliding_window, prompt_size - _sliding_window});
split_op(layer_in, tmp, layer_in);
layer_ins.push_back(std::move(tmp));
}
}

std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
std::unique_ptr<StorageView> layer_attention;
if (attention && heads_to_select)
layer_attention = std::make_unique<StorageView>(dtype, device);
for (size_t i = 0; i < layer_ins.size(); ++i) {
auto layer_in_chunk = layer_ins[i];
for (size_t l = 0; l < _layers.size(); ++l) {
StorageView* cached_self_attn_keys = nullptr;
StorageView* cached_self_attn_values = nullptr;
StorageView* cached_attn_keys = nullptr;
StorageView* cached_attn_values = nullptr;

if (step >= 0) {
const std::string l_str = std::to_string(l);
cached_self_attn_keys = &state.at("self_keys_" + l_str);
cached_self_attn_values = &state.at("self_values_" + l_str);
if (_with_encoder_attention) {
cached_attn_keys = &state.at("memory_keys_" + l_str);
cached_attn_values = &state.at("memory_values_" + l_str);
}
}

(*_layers[l])(layer_in,
input_lengths_mask.get(),
memory,
memory_lengths_mask.get(),
cached_self_attn_keys,
cached_self_attn_values,
cached_attn_keys,
cached_attn_values,
layer_out,
layer_attention.get(),
input_padder.get(),
memory_padder.get(),
return_normalized_attention(),
&position_bias);
layer_in = std::move(layer_out);
std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
std::unique_ptr<StorageView> layer_attention;
if (attention && heads_to_select)
layer_attention = std::make_unique<StorageView>(dtype, device);

dim_t offset = _sliding_window * i + step;
if (i > 0) {
auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
tmp_lengths,
_num_heads,
max_tokens,
/*mask_future=*/true,
multi_query);

const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1));
// reuse tmp_lengths
slide_lengths_op(lengths_mask, tmp_lengths);
input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths));
}

if (layer_attention) {
alignment_heads.emplace_back(dtype, device);
ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
(*_layers[l])(layer_in_chunk,
input_lengths_mask.get(),
memory,
memory_lengths_mask.get(),
cached_self_attn_keys,
cached_self_attn_values,
cached_attn_keys,
cached_attn_values,
layer_out,
layer_attention.get(),
input_padder.get(),
memory_padder.get(),
return_normalized_attention(),
&position_bias,
offset);
layer_in_chunk = std::move(layer_out);

if (layer_attention) {
alignment_heads.emplace_back(dtype, device);
ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
}
}
layer_in = std::move(layer_in_chunk);
}

if (step == 0) {
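On the decoder side, a prompt longer than `_sliding_window` is split into window-sized chunks (the Split loop above), each chunk is run through all layers with `offset = _sliding_window * i + step`, and chunks after the first get a causal length mask sized for the cached window plus the new tokens. A simplified sketch of the chunk boundaries themselves (hypothetical helper, not the actual CTranslate2 code):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the chunking loop added to the decoder:
// split a long prompt into sliding_window-sized pieces that are decoded
// one after another.
std::vector<std::vector<int32_t>> chunk_prompt(const std::vector<int32_t>& prompt,
                                               std::size_t sliding_window) {
  std::vector<std::vector<int32_t>> chunks;
  if (sliding_window == 0 || prompt.size() <= sliding_window) {
    chunks.push_back(prompt);
    return chunks;
  }
  for (std::size_t begin = 0; begin < prompt.size(); begin += sliding_window) {
    const std::size_t end = std::min(begin + sliding_window, prompt.size());
    chunks.emplace_back(prompt.begin() + static_cast<std::ptrdiff_t>(begin),
                        prompt.begin() + static_cast<std::ptrdiff_t>(end));
  }
  return chunks;
}

// Chunk i is then processed with offset = sliding_window * i (+ step), so
// rotary embeddings and the trimmed cache still see absolute positions.
```

Note that the committed code uses ops::Split on the StorageView directly; the token-vector version here only shows the boundary arithmetic.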
(Diffs for the remaining changed files are not shown here.)