Sliding window + chunking input for mistral model #1524

Merged 3 commits on Nov 17, 2023
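In brief, the diff adds a Slide op, caps each decoder layer's self-attention KV cache at the model's sliding_window, and pre-fills prompts longer than the window in window-sized chunks. As rough orientation, a minimal standalone sketch of the cache-capping idea, with plain containers standing in for CTranslate2's StorageView caches (the struct and names here are illustrative, not part of the PR):

```cpp
#include <cstddef>
#include <deque>
#include <utility>
#include <vector>

// Conceptual sketch only: the per-layer cache never holds more than `window`
// time steps; the real code achieves this with concat + Slide on the cached
// key/value tensors.
struct SlidingKVCache {
  std::size_t window;
  std::deque<std::vector<float>> keys;    // one entry per cached position
  std::deque<std::vector<float>> values;

  void append(std::vector<float> k, std::vector<float> v) {
    keys.push_back(std::move(k));
    values.push_back(std::move(v));
    while (keys.size() > window) {        // oldest positions fall out of the window
      keys.pop_front();
      values.pop_front();
    }
  }
};
```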
5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -134,7 +134,7 @@ set(SOURCES
src/ops/bias_add.cc
src/ops/bias_add_cpu.cc
src/ops/concat.cc
src/ops/concat_split_cpu.cc
src/ops/concat_split_slide_cpu.cc
src/ops/conv1d.cc
src/ops/conv1d_cpu.cc
src/ops/cos.cc
@@ -168,6 +168,7 @@ set(SOURCES
src/ops/softmax.cc
src/ops/softmax_cpu.cc
src/ops/split.cc
src/ops/slide.cc
src/ops/sub.cc
src/ops/swish.cc
src/ops/tanh.cc
@@ -505,7 +506,7 @@ if (WITH_CUDA)
src/cuda/utils.cc
src/ops/alibi_add_gpu.cu
src/ops/bias_add_gpu.cu
src/ops/concat_split_gpu.cu
src/ops/concat_split_slide_gpu.cu
src/ops/conv1d_gpu.cu
src/ops/dequantize_gpu.cu
src/ops/gather_gpu.cu
3 changes: 2 additions & 1 deletion include/ctranslate2/layers/attention.h
@@ -35,7 +35,8 @@ namespace ctranslate2 {
const Padder* queries_padder = nullptr,
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr) const;
StorageView* position_bias = nullptr,
dim_t offset = 0) const;

bool has_positional_embeddings() const {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/transformer.h
@@ -91,7 +91,8 @@ namespace ctranslate2 {
const Padder* input_padder = nullptr,
const Padder* memory_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr) const;
StorageView* position_bias = nullptr,
dim_t offset = 0) const;

DataType output_type() const override {
return _ff.output_type();
@@ -209,6 +210,7 @@ namespace ctranslate2 {
std::vector<std::vector<dim_t>> _alignment_heads;
bool _average_alignment_heads;
Dense _proj;
const dim_t _sliding_window;
};

}
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
@@ -36,3 +36,4 @@
#include "median_filter.h"
#include "rotary.h"
#include "alibi_add.h"
#include "slide.h"
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/slide.h
@@ -0,0 +1,26 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

class Slide : public Op {
public:
Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy = false);

void operator()(const StorageView& input, StorageView& output) const;
private:
dim_t _axis;
dim_t _index;
dim_t _size;
bool _no_copy;

void check_arguments() const;

template <Device D, typename T>
void compute(const StorageView& input, StorageView& output, const dim_t& index) const;
};

}
}
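The header above is the whole public surface of the new op: copy `size` entries starting at `index` along `axis`. A hedged sketch of those semantics for a dense row-major tensor, using plain vectors rather than StorageView (the helper below is illustrative, not the actual kernel in concat_split_slide_cpu.cc / concat_split_slide_gpu.cu):

```cpp
#include <cstddef>
#include <vector>

// Copy `size` entries starting at `index` along `axis` of a dense row-major
// tensor described by `shape`. Illustrative only.
std::vector<float> slide(const std::vector<float>& input,
                         const std::vector<std::size_t>& shape,
                         std::size_t axis, std::size_t index, std::size_t size) {
  std::size_t outer = 1, inner = 1;
  for (std::size_t i = 0; i < axis; ++i) outer *= shape[i];
  for (std::size_t i = axis + 1; i < shape.size(); ++i) inner *= shape[i];
  const std::size_t axis_dim = shape[axis];

  std::vector<float> output;
  output.reserve(outer * size * inner);
  for (std::size_t o = 0; o < outer; ++o) {
    const float* first = input.data() + (o * axis_dim + index) * inner;
    output.insert(output.end(), first, first + size * inner);
  }
  return output;
}
```

In the attention changes below, the op is always applied with axis 2, the time dimension of the cached keys and values.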
3 changes: 3 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -171,6 +171,8 @@ def __init__(
self.alibi = alibi
self.alibi_use_positive_positions = alibi_use_positive_positions
self.scale_alibi = scale_alibi
if sliding_window is not None:
self.sliding_window = np.dtype("int32").type(sliding_window)
if (
not relative_position
and not relative_attention_bias
@@ -225,6 +227,7 @@ def __init__(
relative_attention_bias=relative_attention_bias,
rms_norm=rms_norm,
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
)
self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)

28 changes: 23 additions & 5 deletions src/layers/attention.cc
@@ -430,7 +430,8 @@ namespace ctranslate2 {
const Padder* queries_padder,
const Padder* values_padder,
bool return_normalized_attention,
StorageView* position_bias) const {
StorageView* position_bias,
dim_t offset) const {
PROFILE("MultiHeadAttention");
const Device device = queries.device();
const DataType dtype = queries.dtype();
@@ -449,6 +450,8 @@ namespace ctranslate2 {

dim_t beam_size = 1;

bool prefilling = (_sliding_window > 0 && values_lengths);

if (!_self_attention) {
queries_proj = std::move(fused_proj);

@@ -507,10 +510,6 @@ namespace ctranslate2 {
}

if (_rotary_embeddings) {
const dim_t offset = (cached_keys && !cached_keys->empty()
? cached_keys->dim(_cache_time_dim)
: 0);

if (_merge_time_and_head_dims) {
queries_proj.reshape({queries_proj.dim(0), -1, _d_model});
split_heads(queries_proj, _num_heads);
@@ -536,6 +535,15 @@ namespace ctranslate2 {
concat_op({&tmp, &keys_proj}, *cached_keys);
tmp = std::move(*cached_values);
concat_op({&tmp, &values_proj}, *cached_values);

if (!prefilling && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) {
// only for generation
const ops::Slide slide_op(2, 1, cached_keys->shape()[2] - 1);
slide_op(*cached_keys, tmp);
*cached_keys = std::move(tmp);
slide_op(*cached_values, tmp);
*cached_values = std::move(tmp);
}
}
}
}
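During generation (the non-prefill path), the new block above appends the current step's keys/values and then, once the cache exceeds the window, slides off exactly one position: Slide(2, 1, size - 1) keeps everything but the oldest time step. A toy trace of that behaviour, with ints standing in for per-position key/value tensors (assumes a window of 4; not library code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t sliding_window = 4;
  std::vector<int> cached = {0, 1, 2, 3};   // cache already holds a full window

  cached.push_back(4);                      // concat of the new step's keys/values
  if (cached.size() > sliding_window)       // same condition as the diff
    cached.erase(cached.begin());           // Slide(2, 1, size - 1): drop the oldest

  assert((cached == std::vector<int>{1, 2, 3, 4}));
  return 0;
}
```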
@@ -564,6 +572,16 @@ namespace ctranslate2 {
_alibi,
position_bias);

if (prefilling && cached_keys->shape()[2] > _sliding_window) {
// set only last sliding_window tokens to cached_keys and cached_values after computing attention
const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window);
StorageView tmp(dtype, device);
slide_op(*cached_keys, tmp);
*cached_keys = std::move(tmp);
slide_op(*cached_values, tmp);
*cached_values = std::move(tmp);
}

if (_merge_time_and_head_dims) {
context.reshape(queries.shape());
if (queries_padder)
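The second new block (after the dot-product attention call) handles prefill chunks: once attention over the chunk has been computed, only the last sliding_window positions are kept in the cache via Slide(2, size - window, window). A hedged sketch of that trimming with flat vectors in place of the key/value StorageViews (function and parameter names are illustrative):

```cpp
#include <cstddef>
#include <vector>

// Keep only the last `window` time steps of a cache that stores `per_position`
// elements per time step. Illustrative only.
void trim_cache_after_prefill(std::vector<float>& cached_keys,
                              std::vector<float>& cached_values,
                              std::size_t per_position,
                              std::size_t window) {
  const std::size_t positions = cached_keys.size() / per_position;
  if (positions <= window)
    return;
  const std::ptrdiff_t drop =
      static_cast<std::ptrdiff_t>((positions - window) * per_position);
  cached_keys.erase(cached_keys.begin(), cached_keys.begin() + drop);
  cached_values.erase(cached_values.begin(), cached_values.begin() + drop);
}
```

Called once per prefill chunk, this keeps the cache entering the next chunk at no more than `window` positions.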
134 changes: 91 additions & 43 deletions src/layers/transformer.cc
@@ -121,7 +121,8 @@ namespace ctranslate2 {
const Padder* input_padder,
const Padder* memory_padder,
bool return_normalized_attention,
StorageView* position_bias) const {
StorageView* position_bias,
dim_t offset) const {
PROFILE("TransformerDecoderLayer");

const DataType dtype = input.dtype();
@@ -149,7 +150,8 @@ namespace ctranslate2 {
input_padder,
input_padder,
true,
position_bias);
position_bias,
offset);

if (_post_attention_layer_norm)
(*_post_attention_layer_norm)(input, hidden);
@@ -172,7 +174,8 @@ namespace ctranslate2 {
input_padder,
input_padder,
true,
position_bias);
position_bias,
offset);

StorageView context(dtype, device);
if (_encoder_attention) {
@@ -330,7 +333,8 @@ namespace ctranslate2 {
? nullptr
: build_position_encoder(model, scope + "/position_encodings", _embeddings))
, _with_encoder_attention(_layers.front()->has_cross_attention())
, _proj(model, scope + "/projection") {
, _proj(model, scope + "/projection")
, _sliding_window(model.get_attribute_with_default<int32_t>(scope + "/sliding_window", 0)) {

dim_t alignment_layer = (
model.get_attribute_with_default<int32_t>(scope + "/alignment_layer", -1));
@@ -467,7 +471,13 @@ namespace ctranslate2 {
(*_layernorm_embedding)(layer_in, layer_in);

const dim_t batch_size = layer_in.dim(0);
const dim_t max_time = layer_in.dim(1);
dim_t max_time;

if (_sliding_window > 0 && layer_in.dim(1) > _sliding_window) {
max_time = _sliding_window;
} else
max_time = layer_in.dim(1);

const bool allow_padding_removal = Padder::allow_padding_removal(_device, _compute_type);

std::unique_ptr<const Padder> input_padder;
@@ -479,14 +489,14 @@ namespace ctranslate2 {
lengths = input_lengths.get();
}

bool multi_query = _layers.front()->get_self_attention().multi_query();

if (lengths) {
if (allow_padding_removal) {
input_padder = std::make_unique<Padder>(*lengths, max_time);
input_padder->remove_padding(layer_in);
}

const bool multi_query = _layers.front()->get_self_attention().multi_query();

StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
*lengths,
_num_heads,
@@ -531,47 +541,85 @@ namespace ctranslate2 {

StorageView position_bias(dtype, device);

for (size_t l = 0; l < _layers.size(); ++l) {
StorageView* cached_self_attn_keys = nullptr;
StorageView* cached_self_attn_values = nullptr;
StorageView* cached_attn_keys = nullptr;
StorageView* cached_attn_values = nullptr;

if (step >= 0) {
const std::string l_str = std::to_string(l);
cached_self_attn_keys = &state.at("self_keys_" + l_str);
cached_self_attn_values = &state.at("self_values_" + l_str);
if (_with_encoder_attention) {
cached_attn_keys = &state.at("memory_keys_" + l_str);
cached_attn_values = &state.at("memory_values_" + l_str);
}
std::vector<StorageView> layer_ins;

while (true) {
dim_t prompt_size = layer_in.dim(1);
if (_sliding_window == 0 || prompt_size <= _sliding_window) {
layer_ins.push_back(std::move(layer_in));
break;
}
if (layer_in.dim(1) > _sliding_window) {
StorageView tmp(dtype, device);
const ops::Split split_op(1, {_sliding_window, prompt_size - _sliding_window});
split_op(layer_in, tmp, layer_in);
layer_ins.push_back(std::move(tmp));
}
}

std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
std::unique_ptr<StorageView> layer_attention;
if (attention && heads_to_select)
layer_attention = std::make_unique<StorageView>(dtype, device);
for (size_t i = 0; i < layer_ins.size(); ++i) {
auto layer_in_chunk = layer_ins[i];
for (size_t l = 0; l < _layers.size(); ++l) {
StorageView* cached_self_attn_keys = nullptr;
StorageView* cached_self_attn_values = nullptr;
StorageView* cached_attn_keys = nullptr;
StorageView* cached_attn_values = nullptr;

if (step >= 0) {
const std::string l_str = std::to_string(l);
cached_self_attn_keys = &state.at("self_keys_" + l_str);
cached_self_attn_values = &state.at("self_values_" + l_str);
if (_with_encoder_attention) {
cached_attn_keys = &state.at("memory_keys_" + l_str);
cached_attn_values = &state.at("memory_values_" + l_str);
}
}

(*_layers[l])(layer_in,
input_lengths_mask.get(),
memory,
memory_lengths_mask.get(),
cached_self_attn_keys,
cached_self_attn_values,
cached_attn_keys,
cached_attn_values,
layer_out,
layer_attention.get(),
input_padder.get(),
memory_padder.get(),
return_normalized_attention(),
&position_bias);
layer_in = std::move(layer_out);
std::unique_ptr<StorageView> heads_to_select = get_layer_alignment_heads(l, batch_size);
std::unique_ptr<StorageView> layer_attention;
if (attention && heads_to_select)
layer_attention = std::make_unique<StorageView>(dtype, device);

dim_t offset = _sliding_window * i + step;
if (i > 0) {
auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
tmp_lengths,
_num_heads,
max_tokens,
/*mask_future=*/true,
multi_query);

const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1));
// reuse tmp_lengths
slide_lengths_op(lengths_mask, tmp_lengths);
input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths));
}

if (layer_attention) {
alignment_heads.emplace_back(dtype, device);
ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
(*_layers[l])(layer_in_chunk,
input_lengths_mask.get(),
memory,
memory_lengths_mask.get(),
cached_self_attn_keys,
cached_self_attn_values,
cached_attn_keys,
cached_attn_values,
layer_out,
layer_attention.get(),
input_padder.get(),
memory_padder.get(),
return_normalized_attention(),
&position_bias,
offset);
layer_in_chunk = std::move(layer_out);

if (layer_attention) {
alignment_heads.emplace_back(dtype, device);
ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
}
}
layer_in = std::move(layer_in_chunk);
}

if (step == 0) {
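The restructured decoder loop above is the chunked-prefill driver: a prompt longer than sliding_window is split (ops::Split on the time dimension) into window-sized pieces, each piece is run through all layers, and the offset handed down (_sliding_window * i + step) keeps positions absolute across chunks for the rotary embeddings; chunks after the first also get a length mask covering the window of cached positions plus the chunk itself. A simplified standalone sketch of just the chunking-and-offset logic (the run_layers callback stands in for the per-layer loop; names are illustrative):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Split the prompt into window-sized chunks and report each chunk's absolute
// starting position, mirroring the offset handed to the attention layers.
template <typename RunLayers>
void chunked_prefill(const std::vector<int>& prompt,
                     std::size_t window,
                     RunLayers run_layers) {
  std::size_t start = 0;
  while (start < prompt.size()) {
    const std::size_t len = (window == 0)
        ? prompt.size() - start
        : std::min(window, prompt.size() - start);
    run_layers(&prompt[start], len, /*offset=*/start);
    start += len;
  }
}

int main() {
  std::vector<int> prompt(10);
  for (int i = 0; i < 10; ++i) prompt[i] = i;
  chunked_prefill(prompt, /*window=*/4,
                  [](const int* chunk, std::size_t len, std::size_t offset) {
                    std::printf("chunk at offset %zu, length %zu, first id %d\n",
                                offset, len, chunk[0]);
                  });
  // Prints chunks at offsets 0, 4 and 8 with lengths 4, 4 and 2.
  return 0;
}
```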