Skip to content

Commit

Permalink
add test tensor function
Browse files Browse the repository at this point in the history
  • Loading branch information
rocking5566 committed Nov 7, 2024
1 parent 23cb26b commit 2c5b84c
Showing 1 changed file with 125 additions and 52 deletions.
177 changes: 125 additions & 52 deletions csrc/flash_attn_ck/mha_fwd_kvcache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
#include "flash_common.hpp"

#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "rotary.hpp"

fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype,
int head_size,
int rotary_dim,
bool is_rotary_interleaved)
int head_size,
int rotary_dim,
bool is_rotary_interleaved)
{
rope_enum rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
: rope_enum::half_rotated)
Expand All @@ -19,7 +20,7 @@ fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype,
return fmha_fwd_appendkv_traits{head_size,
head_size,
dtype,
true, // is_v_rowmajor
true, // is_v_rowmajor
rope_type};
}

Expand All @@ -33,13 +34,25 @@ fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,
head_size,
dtype,
false, // is_group_mode
true, // is_v_rowmajor
true, // is_v_rowmajor
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
false}; // do_fp8_static_quant
}

void *test_tensor(const at::Tensor t)
{
int len = 1;

for (int i = 0; i < t.sizes().size(); ++i)
len *= t.sizes()[i];

std::vector<ck_tile::half_t> t_host(len);
void *t_ptr = t.data_ptr();
HIP_CHECK_ERROR(hipMemcpy(t_host.data(), t_ptr, len, hipMemcpyDeviceToHost));
}

fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
const int seqlen_q,
const int seqlen_knew,
Expand Down Expand Up @@ -80,6 +93,14 @@ fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
args.vnew_ptr = vnew.data_ptr();
args.seqlen_k_ptr = seqlens_k_.has_value() ? seqlens_k_.value().data_ptr() : nullptr;

test_tensor(q);
test_tensor(kcache);
test_tensor(knew);
test_tensor(vcache);
test_tensor(vnew);
if (seqlens_k_.has_value())
test_tensor(seqlens_k_.value());

args.seqlen_q = seqlen_q;
args.seqlen_knew = seqlen_knew;
args.batch = b;
Expand All @@ -99,6 +120,8 @@ fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
args.block_table_ptr = block_table.data_ptr();
args.batch_stride_block_table = block_table.stride(0);
args.page_block_size = page_block_size;

test_tensor(block_table);
}
else
{
Expand All @@ -107,8 +130,7 @@ fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
args.page_block_size = 0;
}

args.cache_batch_idx = cache_batch_idx_.has_value() ?
reinterpret_cast<int *>(cache_batch_idx_.value().data_ptr()) : nullptr;
args.cache_batch_idx = cache_batch_idx_.has_value() ? reinterpret_cast<int *>(cache_batch_idx_.value().data_ptr()) : nullptr;

args.batch_stride_q = q.stride(0);
args.stride_q = q.stride(1);
Expand Down Expand Up @@ -171,6 +193,9 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
args.q_ptr = q.data_ptr();
args.k_ptr = k.data_ptr();
args.v_ptr = v.data_ptr();

HIP_CHECK_ERROR(hipMemset(args.q_ptr, 0, mMemSize));

args.bias_ptr = nullptr;
args.lse_acc_ptr = lse_acc.data_ptr();
args.o_acc_ptr = out_acc.data_ptr();
Expand Down Expand Up @@ -243,13 +268,15 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
args.nhead_stride_o_acc = out_acc.stride(2);
args.stride_o_acc = out_acc.stride(3);

if (has_lse) {
if (has_lse)
{
args.lse_ptr = lse.data_ptr();
args.batch_stride_lse = lse.stride(0);
args.nhead_stride_lse = lse.stride(1);
}

if (alibi_slopes_.has_value()) {
if (alibi_slopes_.has_value())
{
auto alibi_slopes = alibi_slopes_.value();
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
Expand All @@ -266,19 +293,19 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
}

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<const at::Tensor> & /*leftpad_k_*/, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<const at::Tensor> & /*leftpad_k_*/, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
Expand All @@ -295,15 +322,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
CHECK_DEVICE(q);
CHECK_DEVICE(kcache);
CHECK_DEVICE(vcache);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

at::Tensor block_table;
const bool paged_KV = block_table_.has_value();
if (paged_KV) {
if (paged_KV)
{
TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
block_table = block_table_.value();
CHECK_DEVICE(block_table);
Expand All @@ -330,20 +360,29 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; }
if (seqlen_q == 1 && !alibi_slopes_.has_value())
{
is_causal = false;
}
if (is_causal)
{
window_size_right = 0;
}

mask_info mask;
if (is_causal) {
if (is_causal)
{
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
window_size_right = 0;
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual
}
else if (window_size_left == -1 && window_size_right == -1) {
else if (window_size_left == -1 && window_size_right == -1)
{
mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
}
else {
else
{
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
Expand All @@ -352,50 +391,70 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
if (seqlenq_ngroups_swapped) {
if (seqlenq_ngroups_swapped)
{
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
}

if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
if (window_size_left >= seqlen_k)
{
window_size_left = -1;
}
if (window_size_right >= seqlen_k)
{
window_size_right = -1;
}

CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
if (!paged_KV) {
if (!paged_KV)
{
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
} else {
}
else
{
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
}

at::Tensor q_padded, kcache_padded, vcache_padded;
if (head_size_og % 8 != 0) {
if (head_size_og % 8 != 0)
{
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
}
else
{
q_padded = q;
kcache_padded = kcache;
vcache_padded = vcache;
}

at::Tensor out;
if (out_.has_value()) {
if (out_.has_value())
{
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
if (head_size_og % 8 != 0)
{
out = torch::empty_like(q_padded);
}
}
else
{
out = torch::empty_like(q_padded);
}

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
auto round_multiple = [](int x, int m)
{ return (x + m - 1) / m * m; };
const int head_size_8x = round_multiple(head_size_og, 8);

// Otherwise the kernel will be launched from cuda:0 device
Expand All @@ -410,30 +469,36 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz

int seqlen_knew = 0;
at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
if (k_.has_value())
{
TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
k = k_.value();
v = v_.value();
TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(k);
CHECK_DEVICE(v);
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
seqlen_knew = k.size(1);
CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
if (head_size_og % 8 != 0) {
if (head_size_og % 8 != 0)
{
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
}
else
{
k_padded = k;
v_padded = v;
}
}

if (seqlens_k_.has_value()) {
if (seqlens_k_.has_value())
{
auto seqlens_k = seqlens_k_.value();
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
Expand All @@ -442,7 +507,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
}

int rotary_dim = 0;
if (rotary_cos_.has_value()) {
if (rotary_cos_.has_value())
{
TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
auto rotary_cos = rotary_cos_.value();
CHECK_DEVICE(rotary_cos);
Expand All @@ -463,8 +529,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
}


if (cache_batch_idx_.has_value()) {
if (cache_batch_idx_.has_value())
{
auto cache_batch_idx = cache_batch_idx_.value();
CHECK_DEVICE(cache_batch_idx);
CHECK_CONTIGUOUS(cache_batch_idx);
Expand All @@ -482,7 +548,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
auto stream = at::cuda::getCurrentCUDAStream().stream();
ck_tile::stream_config stream_config{stream};

if (seqlen_knew > 0 || rotary_dim > 0) {
if (seqlen_knew > 0 || rotary_dim > 0)
{
auto appendkv_traits =
get_ck_fmha_fwd_appendkv_traits(q_dtype_str, head_size_8x, rotary_dim, is_rotary_interleaved);

Expand Down Expand Up @@ -549,18 +616,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz

fmha_fwd_splitkv(splitkv_traits, splitkv_args, stream_config);

if (head_size_og % 8 != 0) {
if (head_size_og % 8 != 0)
{
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
if (out_.has_value()) { out_.value().copy_(out); }
if (k_.has_value()) {
if (out_.has_value())
{
out_.value().copy_(out);
}
if (k_.has_value())
{
// It's expensive to copy the KV cache here for the case where head size not divisible by 8,
// but we don't expect to get this case in practice. This is just so that the code works for that case.
kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
}
}

if (seqlenq_ngroups_swapped) {
if (seqlenq_ngroups_swapped)
{
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
Expand Down

0 comments on commit 2c5b84c

Please sign in to comment.