Skip to content

Commit

Permalink
Support variable length of page attention
Browse files Browse the repository at this point in the history
  • Loading branch information
rocking5566 committed Nov 21, 2024
1 parent 63a64b5 commit 740da39
Showing 1 changed file with 259 additions and 44 deletions.
303 changes: 259 additions & 44 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
false}; // do_fp8_static_quant
}

fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask,
std::string dtype,
int head_size,
bool has_lse,
bool enable_alibi)
{
return fmha_fwd_splitkv_traits{head_size,
head_size,
dtype,
true, // is_group_mode
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
false}; // do_fp8_static_quant
}

fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
bool has_dropout_randval,
const mask_info &mask,
Expand Down Expand Up @@ -142,6 +159,140 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
drop_seed_offset};
}

fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
const mask_info &mask,
const int b,
const int max_seqlen_q,
const int h,
const int h_k,
const int d,
const int page_block_size,
const int num_splits,
float softmax_scale,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor seqlens_q,
const at::Tensor seqlens_k,
c10::optional<at::Tensor> &block_table_,
c10::optional<at::Tensor> &alibi_slopes_,
at::Tensor out,
at::Tensor lse,
at::Tensor lse_acc,
at::Tensor out_acc)
{
// q: (total_q, nheads, d)
// k: (total_k, nheads_k, d)
// v: (total_k, nheads_k, d)
// o: (total_q, nheads, d)

// alibi_slopes:(batch_size, nheads) or (nhead)
// lse: (nheads, total_q)
// lse_acc: (nheads, split, total_q)
// o_acc: (nheads, split, total_q, d)

ck_tile::index_t total_q = q.size(0);
ck_tile::index_t total_k = k.size(0);

fmha_fwd_splitkv_args args;
args.q_ptr = q.data_ptr();
args.k_ptr = k.data_ptr();
args.v_ptr = v.data_ptr();
args.bias_ptr = nullptr;
args.lse_acc_ptr = lse_acc.data_ptr();
args.o_acc_ptr = out_acc.data_ptr();
args.lse_ptr = nullptr;
args.o_ptr = out.data_ptr();

if (block_table_.has_value())
{
auto block_table = block_table_.value();
args.block_table_ptr = block_table.data_ptr();
args.batch_stride_block_table = block_table.stride(0);
args.page_block_size = page_block_size;
}
else
{
args.block_table_ptr = nullptr;
args.batch_stride_block_table = 0;
args.page_block_size = 0;
}

args.cache_batch_idx = nullptr;

args.seqstart_q_ptr = seqlens_q.data_ptr();
args.seqstart_k_ptr = seqlens_k.data_ptr();
args.seqlen_k_ptr = nullptr;

args.seqlen_q = total_q;
args.seqlen_k = total_k;
args.batch = b;
args.max_seqlen_q = max_seqlen_q;
args.hdim_q = d;
args.hdim_v = d;
args.nhead_q = h;
args.nhead_k = h_k;
args.num_splits = num_splits;

args.scale_s = softmax_scale;
args.scale_p = 1;
args.scale_o = 1;

args.batch_stride_q = 0;
args.stride_q = q.stride(0);
args.nhead_stride_q = q.stride(1);

args.batch_stride_k = 0;
args.stride_k = k.stride(0);
args.nhead_stride_k = k.stride(1);

args.batch_stride_v = 0;
args.stride_v = v.stride(0);
args.nhead_stride_v = v.stride(1);

args.batch_stride_o = 0;
args.stride_o = out.stride(0);
args.nhead_stride_o = out.stride(1);

args.batch_stride_bias = 0;
args.stride_bias = 0;
args.nhead_stride_bias = 0;

args.batch_stride_lse = 0;
args.nhead_stride_lse = 0;

args.batch_stride_lse_acc = 0;
args.nhead_stride_lse_acc = lse_acc.stride(0);
args.split_stride_lse_acc = lse_acc.stride(1);

args.batch_stride_o_acc = 0;
args.nhead_stride_o_acc = out_acc.stride(0);
args.split_stride_o_acc = out_acc.stride(1);
args.stride_o_acc = out_acc.stride(2);

if (has_lse) {
args.lse_ptr = lse.data_ptr();
args.batch_stride_lse = 0;
args.nhead_stride_lse = lse.stride(0);
}

if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
args.bias_ptr = alibi_slopes.data_ptr();
args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}

args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);

return args;
}

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
Expand Down Expand Up @@ -180,9 +331,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);

// TODO - Support paged_KV
at::Tensor block_table;
const bool paged_KV = block_table_.has_value();
TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
if (paged_KV) {
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
Expand All @@ -195,10 +351,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size = sizes[2];
const int num_heads_k = k.size(1);
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

const int max_num_blocks_per_seq = 0;
const int num_blocks = 0;
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");

if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case

Expand All @@ -207,7 +365,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
// H/t Daniel Haziza

const int total_q = q.size(0);
const int total_k = k.size(0);

TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
Expand Down Expand Up @@ -235,11 +392,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
}

CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
if (!paged_KV) {
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}

CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

at::Tensor out;
if (out_.has_value()) {
out = out_.value();
Expand All @@ -259,6 +423,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;
if (has_dropout)
TORCH_CHECK(!paged_KV, "Paged KV does not support dropout");

at::Tensor softmax_lse;
// TODO - check gradient, only training require lse
Expand All @@ -280,6 +446,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
if (return_dropout_randval) {p.zero_();}
}

int num_splits = 1;
num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);
TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");

auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size}, opts.dtype(at::kFloat));

int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
Expand All @@ -295,44 +469,85 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
}

if (max_seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
ck_tile::stream_config stream_config{stream};

auto traits =
get_ck_fmha_varlen_fwd_traits(
mask,
q_dtype_str,
head_size,
has_dropout,
has_lse,
alibi_slopes_.has_value());

auto args =
get_ck_fmha_varlen_fwd_args(
has_lse,
return_dropout_randval,
mask,
batch_size,
max_seqlen_q,
num_heads,
num_heads_k,
head_size,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
out,
softmax_lse,
p,
softmax_scale,
p_dropout,
drop_seed_offset);

float t = fmha_fwd(traits, args, stream_config);
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
if (paged_KV)
{
auto traits =
get_ck_fmha_varlen_fwd_splitkv_traits(
mask,
q_dtype_str,
head_size,
has_lse,
alibi_slopes_.has_value());

auto args =
get_ck_fmha_varlen_fwd_splitkv_args(
has_lse,
mask,
batch_size,
max_seqlen_q,
num_heads,
num_heads_k,
head_size,
page_block_size,
num_splits,
softmax_scale,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
block_table_,
alibi_slopes_,
out,
softmax_lse,
softmax_lse_accum,
out_accum);

float t = fmha_fwd_splitkv(traits, args, stream_config);
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd_splitkv");
}
else
{
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);

auto traits =
get_ck_fmha_varlen_fwd_traits(
mask,
q_dtype_str,
head_size,
has_dropout,
has_lse,
alibi_slopes_.has_value());

auto args =
get_ck_fmha_varlen_fwd_args(
has_lse,
return_dropout_randval,
mask,
batch_size,
max_seqlen_q,
num_heads,
num_heads_k,
head_size,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
out,
softmax_lse,
p,
softmax_scale,
p_dropout,
drop_seed_offset);

float t = fmha_fwd(traits, args, stream_config);
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
}
}
else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
Expand Down

0 comments on commit 740da39

Please sign in to comment.